目录
一.引言
二.数据准备
三.随机采样 Sample
四.按权重拆分 randomSplit
五.分层采样 sampleByKey
六.总结
一.引言
使用 Spark 进行机器学习、数据分析等项目时,常常需要对数据进行采样,下面介绍三种最常用的采样方法:
A.随机采样: 适合从原始数据随机筛选一部分数据
B.按比例划分采样:适合拆分训练、测试样本
C.分层采样:适合根据不同需求对不同类型样本加权
二.数据准备
- 数据样式
采样前,首先模拟一批正负比 1:4 的正负样本,其中除了 label 外,还有 group 特征区分样本所属分组,共分为 A、B、C 三个组,比例控制在 1:1:1,累计 10000 条样本。
val spark = SparkSession
.builder //创建spark会话
.master("local") //设置本地模式
.appName("SampleUtil") //设置名称
.getOrCreate() //创建会话变量
spark.sparkContext.setLogLevel("error")
val random = new scala.util.Random()
random.setSeed(999)
val allGroup = Array("A", "B", "C")
import spark.implicits._
// 模拟数据,随机分三个组,正负样本 1:4
val testData = (0 to 10000).map(num => {
val group = allGroup(random.nextInt(3))
val label = if (random.nextDouble() > 0.2) {
0
} else {
1
}
val features = Seq(random.nextInt(10), random.nextInt(10), random.nextInt(10))
(label, group, features)
}).toDF("label", "uGroup", "features")
testData.printSchema()
- 数据统计
上面生成了原始数据的 DataFrame,下面使用 SparkSql 统计下 Label 与 Group 的分布信息。
val testTable = "TestTable"
testData.createOrReplaceTempView(testTable)
showMetrics(testTable, spark)
其中 showMetrics 为两条 sql 语句:
def showMetrics(tableName: String, sparkSession: SparkSession): Unit = {
// 正负样本统计
sparkSession.sql(s"select label,count(*) as cnt from $tableName group by label").show()
// 组统计
sparkSession.sql(s"select uGroup,count(*) as cnt from $tableName group by uGroup order by uGroup").show()
}
三.随机采样 Sample
// 1.随机采样
val randomSampleData = testData.sample(withReplacement=false, fraction = 0.5, seed = 999)
val randomTable = "RandomSampleTable"
randomSampleData.createOrReplaceTempView(randomTable)
showMetrics(randomTable, spark)
随机采样主要有三个参数:
withReplacement - 是否放回,True 有放回情况下一条样本可能会多次抽中
fraction - 采样比例
seed - 随机种子,不同 seed 采样结果不同
fraction = 0.5,所以会保留 50% 的数据,可以看到采样后数据量减半,但是整体比例不会受影响。
四.按权重拆分 randomSplit
// 2.正负样本划分 randomSplit
val fraction = Array(0.8, 0.1, 0.1)
val dataSpilt = testData.randomSplit(fraction)
val (train, test, valid) = (dataSpilt(0), dataSpilt(1), dataSpilt(2))
println(s"Train: ${train.count()} Test: ${test.count()} Valid: ${valid.count()}")
randomSplit 方法根据传入的 Array[Ratio*] 分组比例数组对原始样本进行拆分,上述代码按照 8:1:1 拆分训练集、测试集、验证集。
Train: 7949 Test: 1025 Valid: 1027
五.分层采样 sampleByKey
分层采样需要使用 keyBy 生成 pairRDD,通过指定 key 的采样率实现分层采样,这里采样比例 Map 为 <T, Double> 的形式,其中 T 为对应 pairRDD 中 key 的形式,下述代码实现了保留全部正样本,并随机挑选一半负样本。
// 3.分层采样
val keyByData = testData.rdd.keyBy(_.getInt(0))
val sampleRatio = Map(0 -> 0.5, 1 -> 1.0)
val addWeight = keyByData.sampleByKey(withReplacement = false, fractions = sampleRatio).map(_._2)
由于 sampleByKey 后获取的是 RDD<T, Row> 的形式,为了进行 spark sql 统计需要将 RDD 转换为 DataFrame:
val schema =
StructType(
StructField("label", IntegerType, false) ::
StructField("uGroup", StringType, false) ::
StructField("features", ArrayType(IntegerType), false) :: Nil
)
println(schema)
val name = "addWeightTable"
addWeightTable.createOrReplaceTempView(name)
showMetrics(name, spark)
Tips:
这里 RDD 转 DF 可以不使用 schema,此时字段变为 _1、_2 ... 为了继续使用 showMetrics 方法,所以指定 schema,除了上面手动指定外,也可以使用 case class 进行推理:
case class User(label: Int, uGroup: String, features: Array[Int])
import org.apache.spark.sql.catalyst.ScalaReflection
val scalaSchema = ScalaReflection.schemaFor[User].dataType.asInstanceOf[StructType]
val encoderSchema = Encoders.product[User].schema
println(encoderSchema)
上面为手动定义的 schema,下面为推理得到的 schema,主要差别在 nullable 参数。
六.总结
randomSplit 按权重拆分训练集、测试集以及 randomByKey 对样本进行分层采样,例如上采样可以使用上述方法多次采样保留更多正样本,也可以指定不同 seed 实现 Bagging 采样。