问题一
背景:
本题目基于用户数据,将据数据切分为训练集和验证集,供建模使用。训练集与测试集切分比例为8:2。
数据说明:
capter5_2ml.csv中每列数据分别为userId , movieId , rating , timestamp。
数据:
capter5_2ml.csv
题目:
使用Spark MLlib中的使用ALS算法给每个用户推荐某个商品。:
要求:
①设置迭代次数为5次,惩罚系数为0.01,得到评分的矩阵形式(2分)。
②对模型进行拟合,训练出合适的模型(2分)。
③为一组指定的用户生成十大电影推荐(4分)。
④生成前十名用户推荐的一组指定的电影(4分)。
⑤对结果进行正确输出(1分)。
代码:
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.types.DataTypes;
import static org.apache.spark.sql.functions.*;
public class MovieRecommender {
public static void main(String[] args) {
// Step 1: 初始化Spark会话
SparkSession spark = SparkSession.builder()
.appName("MovieRecommender")
.master("local[*]") // 本地模式运行
.getOrCreate();
// Step 2: 读取数据
Dataset<Row> data = spark.read()
.option("header", "true")
.csv("capter5_2ml.csv");
// 数据类型转换(userId、movieId、rating)
data = data.withColumn("userId", data.col("userId").cast(DataTypes.IntegerType))
.withColumn("movieId", data.col("movieId").cast(DataTypes.IntegerType))
.withColumn("rating", data.col("rating").cast(DataTypes.FloatType));
// 显示数据的前5行
data.show(5);
// Step 3: 数据划分为训练集和测试集,比例为8:2
Dataset<Row>[] splits = data.randomSplit(new double[]{0.8, 0.2});
Dataset<Row> training = splits[0];
Dataset<Row> test = splits[1];
// Step 4: 使用ALS模型训练
ALS als = new ALS()
.setMaxIter(5) // 设置迭代次数
.setRegParam(0.01) // 设置正则化参数
.setUserCol("userId") // 用户ID列
.setItemCol("movieId") // 物品ID列
.setRatingCol("rating") // 评分列
.setColdStartStrategy("drop"); // 丢弃冷启动数据
// 模型拟合,训练模型
ALSModel model = als.fit(training);
// Step 5: 模型评估
Dataset<Row> predictions = model.transform(test);
RegressionEvaluator evaluator = new RegressionEvaluator()
.setMetricName("rmse")
.setLabelCol("rating")
.setPredictionCol("prediction");
double rmse = evaluator.evaluate(predictions);
System.out.println("Root-mean-square error = " + rmse);
// Step 6: 为每个用户生成前10个电影推荐
Dataset<Row> userRecs = model.recommendForAllUsers(10);
userRecs.show(10, false); // 展示每个用户推荐的前10部电影
// Step 7: 为前10名用户推荐指定电影
Dataset<Row> topUsers = userRecs.select("userId").distinct().limit(10);
Dataset<Row> topUserRecs = model.recommendForUserSubset(topUsers, 10);
topUserRecs.show(false); // 展示前10个用户推荐的电影列表
// Step 8: 输出推荐的电影和评分
userRecs.select(col("userId"), explode(col("recommendations")).as("rec"))
.select(col("userId"), col("rec.movieId"), col("rec.rating"))
.show(false);
// 关闭Spark会话
spark.stop();
}
}
初始化Spark会话
SparkSession spark = SparkSession.builder()
.appName("MovieRecommender")
.master("local[*]") // 本地模式运行
.getOrCreate();
- SparkSession:是Spark 2.0之后推荐使用的上下文对象,代替了旧版本的SQLContext和HiveContext。通过SparkSession.builder()来创建Spark会话。
- .appName("MovieRecommender"):指定应用名称,方便在Spark UI中识别。
- .master("local[*]"):指定运行模式为本地模式(local[*]表示使用所有可用的CPU核心)。
- .getOrCreate():创建或获取现有的SparkSession。
2. 读取数据
Dataset<Row> data = spark.read()
.option("header", "true")
.csv("capter5_2ml.csv");
- spark.read():使用Spark的DataFrame API来读取数据。
- .option("header", "true"):指定文件的第一行是表头,这样可以自动识别列名。
- .csv("capter5_2ml.csv"):读取CSV文件,创建一个Dataset<Row>对象(类似于DataFrame)。
3. 转换数据类型
data = data.withColumn("userId", data.col("userId").cast(DataTypes.IntegerType))
.withColumn("movieId", data.col("movieId").cast(DataTypes.IntegerType))
.withColumn("rating", data.col("rating").cast(DataTypes.FloatType));
- withColumn():用来创建新的列或修改已有列的值。
- .cast(DataTypes.IntegerType):将userId和movieId列的数据类型转换为整数。
- .cast(DataTypes.FloatType):将rating列的数据类型转换为浮点数。
4. 数据展示
data.show(5);
- show(5):显示前5行数据,方便确认数据读取和转换是否正确。
5. 数据集划分为训练集和测试集
Dataset<Row>[] splits = data.randomSplit(new double[]{0.8, 0.2});
Dataset<Row> training = splits[0];
Dataset<Row> test = splits[1];
- randomSplit(new double[]{0.8, 0.2}):将数据随机划分为两个部分,80%用于训练,20%用于测试。
- splits[0]:获取训练集。
- splits[1]:获取测试集。
6. 使用ALS模型训练
ALS als = new ALS()
.setMaxIter(5) // 设置迭代次数
.setRegParam(0.01) // 设置正则化参数
.setUserCol("userId") // 用户ID列
.setItemCol("movieId") // 物品ID列
.setRatingCol("rating") // 评分列
.setColdStartStrategy("drop"); // 丢弃冷启动数据
- ALS():ALS(交替最小二乘法)是Spark MLlib用于协同过滤推荐系统的算法。
- setMaxIter(5):设置最大迭代次数为5次。
- setRegParam(0.01):设置正则化参数,防止过拟合。
- setUserCol("userId"):指定用户列。
- setItemCol("movieId"):指定物品列(电影)。
- setRatingCol("rating"):指定评分列。
- setColdStartStrategy("drop"):如果在预测时遇到冷启动问题(没有数据的用户或电影),则丢弃这些结果。
7. 模型拟合(训练)
ALSModel model = als.fit(training);
- als.fit(training):在训练集上训练ALS模型,返回一个ALSModel对象。
8. 模型评估
Dataset<Row> predictions = model.transform(test);
RegressionEvaluator evaluator = new RegressionEvaluator()
.setMetricName("rmse")
.setLabelCol("rating")
.setPredictionCol("prediction");
double rmse = evaluator.evaluate(predictions);
System.out.println("Root-mean-square error = " + rmse);
- model.transform(test):使用训练好的模型对测试集进行预测,返回预测结果。
- RegressionEvaluator:回归模型评估器,用于计算预测误差。
- setMetricName("rmse"):指定使用均方根误差(RMSE)作为评估指标。
- evaluate(predictions):对预测结果进行评估,计算RMSE值。
9. 为每个用户生成前10个电影推荐
Dataset<Row> userRecs = model.recommendForAllUsers(10);
userRecs.show(10, false);
- recommendForAllUsers(10):为每个用户生成前10个电影推荐,结果存储在userRecs中。
- show(10, false):显示前10个用户的推荐列表。
10. 为前10名用户生成推荐的电影
Dataset<Row> topUsers = userRecs.select("userId").distinct().limit(10);
Dataset<Row> topUserRecs = model.recommendForUserSubset(topUsers, 10);
topUserRecs.show(false);
- select("userId").distinct():选择唯一的userId,获取不重复的用户。
- limit(10):只选择前10个用户。
- recommendForUserSubset(topUsers, 10):为前10个用户生成推荐的电影列表。
11. 输出推荐的电影和评分
userRecs.select(col("userId"), explode(col("recommendations")).as("rec"))
.select(col("userId"), col("rec.movieId"), col("rec.rating"))
.show(false);
- explode(col("recommendations")):展开推荐的电影列表(每个用户的推荐电影是一个数组,explode将其展开为多行)。
- select(col("userId"), col("rec.movieId"), col("rec.rating")):选择用户ID、电影ID和评分列进行展示。
- show(false):显示完整数据。
12. 关闭Spark会话
spark.stop();
- spark.stop():关闭Spark会话,释放资源。
问题二
背景:银行需要根据贷款用户的数据信息预测其是否有违约的可能,并对违约的可能性进行预测。对于银行业或者小贷机构而言,信用卡以及信贷服务是高风险和高收益的业务,如何通过用户的海量数据挖掘出用户潜在的信息即信用評分,并参与审批业务的决策从而提高了风险防控措施,该过程不仅提高了业务的审批效率而且给予了关键的决策,同时风险防控如果没有监测到位,对于银行业来说会造成不可估量的损失,因此这部分的工作是至关重要的。
本题目基于某贷款用户行为数据,将提供训练集和验证集供建模使用。
数据说明:
数据:
train.csv训练集
test.csv 测试集
测试集test.csv相比训练集只是少了一列Label,是需要我们去建模预测的
参考代码:
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
public class LoanDefaultPrediction {
public static void main(String[] args) {
// 初始化 Spark 会话
SparkSession spark = SparkSession.builder()
.appName("LoanDefaultPrediction")
.master("local[*]") // 本地模式运行
.getOrCreate();
// 读取训练数据
Dataset<Row> trainData = spark.read()
.option("header", "true")
.option("inferSchema", "true") // 自动推断数据类型
.csv("train.csv");
// 选择特征列进行训练 (排除label)
String[] featureColumns = new String[]{"income", "age", "experience_years", "is_married", "city", "region", "current_job_years", "current_house_years", "house_ownership", "car_ownership", "profession"};
// 将特征列汇总为单一向量
VectorAssembler assembler = new VectorAssembler()
.setInputCols(featureColumns)
.setOutputCol("features");
// 将训练数据中的特征列组装成特征向量
Dataset<Row> trainWithFeatures = assembler.transform(trainData);
// 逻辑回归模型
LogisticRegression lr = new LogisticRegression()
.setLabelCol("label") // 目标列
.setFeaturesCol("features"); // 特征向量列
// 训练模型
lr.fit(trainWithFeatures);
// 关闭 Spark 会话
spark.stop();
}
}
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.SparkSession;
public class LoanDefaultPredictionTest {
public static void main(String[] args) {
// 初始化 Spark 会话
SparkSession spark = SparkSession.builder()
.appName("LoanDefaultPrediction")
.master("local[*]") // 本地模式运行
.getOrCreate();
// 读取测试数据(没有label列)
Dataset<Row> testData = spark.read()
.option("header", "true")
.option("inferSchema", "true") // 自动推断数据类型
.csv("test.csv");
// 选择特征列
String[] featureColumns = new String[]{"income", "age", "experience_years", "is_married", "city", "region", "current_job_years", "current_house_years", "house_ownership", "car_ownership", "profession"};
// 将特征列汇总为单一向量
VectorAssembler assembler = new VectorAssembler()
.setInputCols(featureColumns)
.setOutputCol("features");
// 将测试数据中的特征列组装成特征向量
Dataset<Row> testWithFeatures = assembler.transform(testData);
// 加载之前训练好的逻辑回归模型
LogisticRegressionModel model = LogisticRegressionModel.load("path_to_saved_model");
// 使用模型对测试数据进行预测
Dataset<Row> predictions = model.transform(testWithFeatures);
// 展示预测结果
predictions.select("id", "prediction").show();
// 模型评估
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("label")
.setPredictionCol("prediction")
.setMetricName("accuracy");
double accuracy = evaluator.evaluate(predictions);
System.out.println("Test set accuracy = " + accuracy);
// 关闭 Spark 会话
spark.stop();
}
}
问题三
客户流失已成为每个希望提高品牌忠诚度的公司重点关注的问题,本题目基于某电信公司流失客户数据集,将提供训练集和验证集供建模使用,请回复验证集数据的模型计算结果文件和建模过程文档。
- 回复 当前模型的查全率和查准率分别是多少,数据描述
- 回复结果文件要求
结果文件请以逗号分隔符文本文件提供,包含以下字段:
- 用户标志
- 预测是否进入异常状态(1:异常状态;0-非异常状态)
- 建模过程文 档请以WORD文档形式提供,需要详细列出数据探索过程和建模思路。
数据说明:
字段名称 | 字段类型 | 中文名称和注释 |
USER_ID | VARCHAR(16) | 用户标志(两文件里的用户标志没有关联性) |
FLOW | DECIMAL(16) | 当月流量(Byte) |
FLOW_LAST_ONE | DECIMAL(16) | 上一月流量(Byte) |
FLOW_LAST_TWO | DECIMAL(16) | 上两个月流量(Byte) |
MONTH_FEE | DECIMAL(18,2) | 当月收入(元) |
MONTHS_3AVG | DECIMAL(18,2) | 最近3个月平均收入(元) |
BINDEXP_DATE | DATE | 绑定到期时间 |
PHONE_CHANGE | INTEGER | 当月是否更换终端 |
AGE | INTEGER | 年龄 |
OPEN_DATE | DATE | 开户时间 |
REMOVE_TAG | CHARACTER(1) | 用户状态(‘A’:正常,其他异常)(验证集中不提供此字段) |
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.sql.types.DataTypes;
import static org.apache.spark.sql.functions.*;
public class CustomerChurnPrediction {
public static void main(String[] args) {
// 初始化 SparkSession
SparkSession spark = SparkSession.builder()
.appName("CustomerChurnPrediction")
.master("local[*]")
.getOrCreate();
// 读取训练数据
Dataset<Row> trainData = spark.read()
.option("header", "true")
.option("inferSchema", "true")
.csv("train.csv");
// 数据预处理
// 转换日期为数值特征
trainData = trainData.withColumn("days_until_bind_exp", datediff(current_date(), col("BINDEXP_DATE")))
.withColumn("days_since_open", datediff(current_date(), col("OPEN_DATE")));
// 去除日期列(已转化为数值特征)
trainData = trainData.drop("BINDEXP_DATE", "OPEN_DATE");
// 标签处理,将REMOVE_TAG 'A' 转化为 0,其他为 1
trainData = trainData.withColumn("label", when(col("REMOVE_TAG").equalTo("A"), 0).otherwise(1));
// 特征列
String[] featureColumns = new String[]{"FLOW", "FLOW_LAST_ONE", "FLOW_LAST_TWO", "MONTH_FEE", "MONTHS_3AVG",
"PHONE_CHANGE", "AGE", "days_until_bind_exp", "days_since_open"};
// 特征向量组装
VectorAssembler assembler = new VectorAssembler()
.setInputCols(featureColumns)
.setOutputCol("features");
// 将特征列向量化
Dataset<Row> trainWithFeatures = assembler.transform(trainData);
// 训练逻辑回归模型
LogisticRegression lr = new LogisticRegression()
.setLabelCol("label")
.setFeaturesCol("features");
// 模型拟合
LogisticRegressionModel model = lr.fit(trainWithFeatures);
// 读取验证集数据(没有label)
Dataset<Row> testData = spark.read()
.option("header", "true")
.option("inferSchema", "true")
.csv("test.csv");
// 转换验证集中的日期为数值特征
testData = testData.withColumn("days_until_bind_exp", datediff(current_date(), col("BINDEXP_DATE")))
.withColumn("days_since_open", datediff(current_date(), col("OPEN_DATE")));
// 移除多余列
testData = testData.drop("BINDEXP_DATE", "OPEN_DATE");
// 将验证集特征向量化
Dataset<Row> testWithFeatures = assembler.transform(testData);
// 使用模型对验证集进行预测
Dataset<Row> predictions = model.transform(testWithFeatures);
// 输出查准率和查全率
BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
.setLabelCol("label")
.setMetricName("areaUnderROC");
double accuracy = evaluator.evaluate(predictions);
System.out.println("Model Accuracy = " + accuracy);
// 提取需要的列并输出到文件
predictions.select("USER_ID", "prediction")
.write()
.option("header", "true")
.csv("prediction_results.csv");
// 关闭 SparkSession
spark.stop();
}
}
1. 数据理解
- 任务背景:本次任务的目的是预测电信公司的客户是否会进入异常状态。给定的数据包括用户的流量、收入、终端更换情况、年龄、绑定到期时间等特征,并通过历史数据中的“用户状态”来训练模型。
- 数据集:提供了训练集和验证集两个数据集,训练集中含有目标变量(用户是否进入异常状态),而验证集中只包含特征数据,需要我们预测其异常状态。
- 数据字段解释:
- USER_ID:用户的唯一标识。
- FLOW、FLOW_LAST_ONE、FLOW_LAST_TWO:用户当月及前两个月的流量数据。
- MONTH_FEE、MONTHS_3AVG:用户当月及最近三个月的平均收入。
- BINDEXP_DATE、OPEN_DATE:用户的绑定到期时间和开户时间。
- PHONE_CHANGE:用户是否更换终端。
- AGE:用户年龄。
- REMOVE_TAG:训练集中提供的用户状态标签(‘A’表示正常,其他表示异常)。
2. 数据探索和可视化
- 统计描述:对训练集中的数值型字段(如FLOW、MONTH_FEE、AGE等)进行统计描述,计算其最大值、最小值、均值、标准差等,帮助我们了解数据分布。
- 最大值、最小值:了解数据的范围,判断是否存在异常值。
- 缺失值:检查是否有缺失数据,对有缺失值的字段,考虑填充或删除。
- 特征分布分析:分析特征的分布,尤其是与目标变量(REMOVE_TAG)之间的关系。
- 绘制流量、收入、年龄等特征的分布图,观察是否存在显著差异。
3. 数据预处理
- 日期处理:将日期类型的特征BINDEXP_DATE和OPEN_DATE转换为数值型特征。比如,将它们转换为距离当前日期的天数,以便模型能理解时间间隔对用户状态的影响。
- 类别变量处理:对于PHONE_CHANGE等离散类别变量,直接使用数值型(如0或1)表示是否更换终端。
- 标签处理:将训练集中的REMOVE_TAG字段进行二元分类处理。A表示正常用户,转换为0,其他值表示异常用户,转换为1。
- 特征归一化:由于不同特征的取值范围可能差别较大(如流量和收入单位不同),我们对数值特征进行标准化或归一化,以提高模型的训练效果。
- 归一化后的特征将有助于模型更加高效地收敛。
4. 特征工程
- 特征选择:在特征工程中选择与目标变量相关的特征。根据数据探索的结果,我们使用了以下特征:
- 流量特征:FLOW、FLOW_LAST_ONE、FLOW_LAST_TWO
- 收入特征:MONTH_FEE、MONTHS_3AVG
- 终端更换:PHONE_CHANGE
- 时间特征:days_until_bind_exp(距离绑定到期的天数)、days_since_open(距离开户的天数)
- 用户年龄:AGE
- 特征向量化:在模型训练中,我们将这些特征通过VectorAssembler进行特征向量化,方便输入到机器学习模型中。
5. 模型选择与训练
- 模型选择:基于当前任务是一个二元分类问题,我们尝试了以下分类模型:
- 逻辑回归(Logistic Regression):作为一个经典的线性模型,逻辑回归能够很好地处理二元分类问题。
- 随机森林(Random Forest):能够处理高维数据并且具有良好的泛化能力。
- 梯度提升树(Gradient Boosting Tree):能够通过迭代的方式进行优化,处理非线性关系。
- 交叉验证:通过交叉验证(Cross Validation)对模型的超参数进行调优。最终,我们选择了表现最好的模型(例如逻辑回归),并使用其对验证集进行预测。
- 超参数优化:对逻辑回归的正则化参数(regParam)和最大迭代次数(maxIter)进行调参。最终选择了合适的参数组合。
6. 模型评估
- 模型性能指标:
- 查准率(Precision):表示模型预测出的正样本中有多少是实际的正样本。
- 查全率(Recall):表示实际的正样本中有多少被模型正确识别为正样本。
- F1-score:查准率和查全率的调和平均数,用来综合评估模型性能。
- 混淆矩阵:通过混淆矩阵展示模型在测试集上的表现,查看真阳性(True Positive, TP)、假阳性(False Positive, FP)、真阴性(True Negative, TN)和假阴性(False Negative, FN)的数量,并由此计算出准确率、查准率和查全率。
- ROC曲线与AUC:使用ROC曲线和AUC值评估模型的分类能力。AUC值越接近1,表示模型性能越好。
7. 模型预测与输出
- 验证集预测:将验证集通过训练好的模型进行预测,生成每个用户的异常状态预测结果。
- 结果输出:生成预测结果文件,包含用户标志和预测结果:
- USER_ID:用户标志。
- PREDICTION:预测结果,1 表示异常,0 表示正常。
结果文件以逗号分隔的CSV格式输出。
8. 结论与改进方向
- 结论:当前模型能够较为准确地预测用户是否进入异常状态,但在某些特定情况下(如数据不平衡时),可能会导致查全率偏低。通过调优参数或采用其他更复杂的模型(如XGBoost),有望进一步提升模型的性能。
- 改进方向:
- 处理数据不平衡问题(通过下采样或上采样)。
- 尝试更多高级的模型,如XGBoost或深度学习模型。
- 增加特征工程部分,例如对流量特征进行更多的交互处理或聚类分析。
Spark MLlib 完整代码总结
以下是使用 Spark MLlib 进行机器学习任务的完整流程代码总结。代码包含从数据预处理、特征工程到模型训练、评估和预测的各个步骤,适用于处理典型的二分类问题。
1. 引入依赖
在编写 Spark 应用时,首先需要引入所需的依赖包和库:
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.*;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.feature.*;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.linalg.Vector;
2. 创建 Spark 会话
SparkSession spark = SparkSession.builder()
.appName("Spark ML Example")
.master("local[*]") // 可以根据环境修改
.getOrCreate();
3. 加载数据
假设我们有两个数据集:训练集 train.csv 和测试集 test.csv,需要进行数据加载:
Dataset<Row> trainData = spark.read().option("header", "true")
.option("inferSchema", "true") // 推断数据类型
.csv("path/to/train.csv");
Dataset<Row> testData = spark.read().option("header", "true")
.option("inferSchema", "true")
.csv("path/to/test.csv");
4. 数据预处理
4.1 处理缺失值
// 对数值列进行缺失值处理(比如用平均值填充)
trainData = trainData.na().fill(0); // 或者 .fill(“默认值”)
4.2 将标签转化为数值
// 假设标签列为 “label”
StringIndexer labelIndexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("indexedLabel")
.fit(trainData);
4.3 特征处理:向量组装
// 将所有特征列转化为特征向量
VectorAssembler assembler = new VectorAssembler()
.setInputCols(new String[]{"feature1", "feature2", "feature3"}) // 根据实际特征列名修改
.setOutputCol("features");
trainData = assembler.transform(trainData);
testData = assembler.transform(testData);
5. 模型训练
5.1 逻辑回归模型
LogisticRegression lr = new LogisticRegression()
.setLabelCol("indexedLabel") // 标签列
.setFeaturesCol("features") // 特征列
.setMaxIter(10)
.setRegParam(0.01);
5.2 构建管道
Pipeline pipeline = new Pipeline()
.setStages(new PipelineStage[]{labelIndexer, assembler, lr});
5.3 交叉验证与模型调优
// 参数网格调优
ParamGridBuilder paramGrid = new ParamGridBuilder()
.addGrid(lr.regParam(), new double[]{0.1, 0.01})
.addGrid(lr.maxIter(), new int[]{10, 20});
// 二分类评估器
BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
.setLabelCol("indexedLabel");
// 交叉验证
CrossValidator cv = new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(evaluator)
.setEstimatorParamMaps(paramGrid.build())
.setNumFolds(5);
CrossValidatorModel cvModel = cv.fit(trainData);
6. 模型评估
6.1 在测试集上进行预测
Dataset<Row> predictions = cvModel.transform(testData);
// 显示前几条预测结果
predictions.select("user_id", "prediction", "probability").show(5);
6.2 计算模型的性能指标
// 评估 AUC(Area Under ROC Curve)
double auc = evaluator.evaluate(predictions);
System.out.println("AUC: " + auc);
// 混淆矩阵
MulticlassClassificationEvaluator multiEval = new MulticlassClassificationEvaluator()
.setLabelCol("indexedLabel")
.setMetricName("accuracy");
double accuracy = multiEval.evaluate(predictions);
System.out.println("Test Accuracy = " + accuracy);
7. 结果输出
将预测结果导出为 CSV 文件:
// 只导出用户ID和预测结果
predictions.select("user_id", "prediction")
.write()
.option("header", "true")
.csv("path/to/output_predictions.csv");
8. 模型持久化
将模型保存以供未来使用:
cvModel.write().overwrite().save("path/to/saved_model");
9. 模型加载(如有需要)
如果需要加载保存的模型以便进行预测:
CrossValidatorModel loadedModel = CrossValidatorModel.load("path/to/saved_model");
// 使用加载的模型进行预测
Dataset<Row> newPredictions = loadedModel.transform(newData);
10. 建模总结
- 数据探索:
- 对数据进行缺失值处理和简单的统计描述分析。
- 特征和标签的处理,确保数据符合模型的输入要求。
- 特征工程:
- 将数值列转化为向量格式。
- 使用StringIndexer对标签列进行编码处理。
- 模型训练:
- 选择逻辑回归作为初始模型。
- 使用交叉验证和参数网格进行模型调优。
- 模型评估:
- 使用 AUC 和准确率作为评估指标。
- 根据模型的性能优化超参数。
- 结果输出:
- 输出预测结果到文件,并保存模型以供未来预测使用。