文章目录
- 一 Redis_Jedis_测试
- 1 Jedis所需要的jar包
- 2 连接Redis注意事项
- 3 测试相关数据类型
- (0)测试连接
- (1)Key
- (2)String
- (3)List
- (4)set
- (5)hash
- (6)zset
- (7)连接池
- 二 将人群包存放到Redis中
- 1 思路分析
- 2 代码实现
- UserGroupMapper.java
- UserGroupServiceImpl.java
- 3 结果验证
- 三 挖掘类标签
- 1 挖掘类标签与机器学习
- 2 模型建立过程
- (1)数据预处理
- (2)特征工程
- (3)算法选取
- (4)模型训练
- (5)评估优化
- (6)投放使用、验证优化
- 四 决策树算法
- 五 用SparkMLlib实现决策树的使用
- 1 利用流水线完成模型训练
- 2 数据准备
- (1)建表语句
- (2)模拟数据
- 3 模块搭建
- 4 创建流水线对象
- 5 增加流水线组件(三个徒弟,一个师傅)
- (1)创建标签索引
- (2)创建特征集合
- (3)创建特征向量索引
- (4)创建分类器
- 6 初始化对象
- 7 训练和预测
- 8 调用工具类,进行测试
- 六 完整代码
- 1 MyPipeline
- 2 StudentGenderTrain
一 Redis_Jedis_测试
1 Jedis所需要的jar包
<dependency>
<groupId>redis.clients</groupId>
<artifactId>jedis</artifactId>
<version>3.2.0</version>
</dependency>
2 连接Redis注意事项
禁用Linux的防火墙:Linux(CentOS7)里执行命令
systemctl stop/disable firewalld.service
redis.conf中注释掉
bind 127.0.0.1
然后将安全模式关闭
protected-mode no
或者不关闭安全模式,设置密码在配置文件中的requirepass中修改
然后执行命令前需要先输入密码
AUTH 密码 之后再执行命令
新建工具类 user_profile_manager_0224\src\main\java\com\atguigu\userprofile\utils\RedisUtil.java
3 测试相关数据类型
(0)测试连接
public class RedisUtil {
public static void main(String[] args) {
Jedis jedis = new Jedis("hadoop101", 6379);
System.out.println(jedis.ping()); // 输出PONG
}
}
(1)Key
public static void main(String[] args) {
Jedis jedis = new Jedis("hadoop101", 6379);
System.out.println(jedis.ping()); // 输出PONG
jedis.set("k1000","v1000");
jedis.set("k2000","v2000");
jedis.set("k3000","v3000");
Set<String> keys = jedis.keys("*");
System.out.println(keys.size());
for(String key:keys){
System.out.println(key);
}
System.out.println(jedis.exists("k3000"));
System.out.println(jedis.ttl("k2000"));
System.out.println(jedis.get("k1000"));
}
(2)String
public static void main(String[] args) {
Jedis jedis = new Jedis("hadoop101", 6379);
System.out.println(jedis.ping()); // 输出PONG
jedis.mset("str1", "v1", "str2", "v2", "str3", "v3");
System.out.println(jedis.mget("str1", "str2", "str3"));
}
(3)List
public static void main(String[] args) {
Jedis jedis = new Jedis("hadoop101", 6379);
System.out.println(jedis.ping()); // 输出PONG
jedis.lpush("mylist","v1","v2","v3");
List<String> list = jedis.lrange("mylist", 0, -1);
for(String element : list){
System.out.println(element);
}
}
(4)set
public static void main(String[] args) {
Jedis jedis = new Jedis("hadoop101", 6379);
System.out.println(jedis.ping()); // 输出PONG
jedis.sadd("sets1","set01","set02","set03","set04");
jedis.sadd("sets2","set02","set03","set04","set05");
Set<String> smembers = jedis.smembers("sets1");
for (String set : smembers) {
System.out.println(set);
}
System.out.println("===================");
jedis.srem("sets1","set02");
System.out.println(jedis.scard("sets1"));
Set<String> sinter = jedis.sinter("sets1", "sets2");
for (String s : sinter) {
System.out.println(s);
}
System.out.println("===================");
Set<String> sunion = jedis.sunion("sets1", "sets2");
for (String s : sunion) {
System.out.println(s);
}
System.out.println("===================");
Set<String> sdiff = jedis.sdiff("sets1", "sets2");
for (String s : sdiff) {
System.out.println(s);
}
}
(5)hash
public static void main(String[] args) {
Jedis jedis = new Jedis("hadoop101", 6379);
System.out.println(jedis.ping()); // 输出PONG
jedis.hset("hash1","userName","zhangsan");
System.out.println(jedis.hget("hash1", "userName"));
HashMap<String, String> map = new HashMap<>();
map.put("userName","lisi");
map.put("age","20");
map.put("gender","nv");
jedis.hmset("hash2",map);
List<String> res = jedis.hmget("hash2", "userName", "age", "gender");
for (String re : res) {
System.out.println(re);
}
}
(6)zset
public static void main(String[] args) {
Jedis jedis = new Jedis("hadoop101", 6379);
System.out.println(jedis.ping()); // 输出PONG
jedis.zadd("zset01", 100d, "z3");
jedis.zadd("zset01", 90d, "l4");
jedis.zadd("zset01", 80d, "w5");
jedis.zadd("zset01", 70d, "z6");
Set<Tuple> zrange = jedis.zrangeWithScores("zset01", 0, -1);
for (Tuple tuple : zrange) {
System.out.println(tuple);
}
}
(7)连接池
为了节省每次连接redis服务带来的消耗,把连接好的实例反复利用。
通过参数管理连接的行为
代码如下
public class RedisUtil {
public static void main(String[] args) {
//Jedis jedis = new Jedis("hadoop101", 6379);
Jedis jedis = RedisUtil.getJedisFromPool();
System.out.println(jedis.ping()); // 输出PONG
jedis.close();
}
private static JedisPool jedisPool=null;
public static Jedis getJedisFromPool(){
if(jedisPool==null){
JedisPoolConfig jedisPoolConfig =new JedisPoolConfig();
jedisPoolConfig.setMaxTotal(200); //最大可用连接数
jedisPoolConfig.setMaxIdle(30); //最大闲置连接数
jedisPoolConfig.setMinIdle(10); //最小闲置连接数
jedisPoolConfig.setBlockWhenExhausted(true); //连接耗尽是否等待
jedisPoolConfig.setMaxWaitMillis(2000); //等待时间
jedisPoolConfig.setTestOnBorrow(true); //取连接的时候进行一下测试 ping pong
jedisPool=new JedisPool(jedisPoolConfig,"hadoop101", 6379 );
return jedisPool.getResource();
}else{
return jedisPool.getResource();
}
}
}
链接池参数说明
- MaxTotal:控制一个pool可分配多少个jedis实例,通过pool.getResource()来获取;如果赋值为-1,则表示不限制;如果pool已经分配了MaxTotal个jedis实例,则此时pool的状态为exhausted。
- maxIdle:控制一个pool最多有多少个状态为idle(空闲)的jedis实例;
- minIdle:控制一个pool最少有多少个状态为idle(空闲)的jedis实例;
- BlockWhenExhausted:连接耗尽是否等待
- MaxWaitMillis:表示当borrow一个jedis实例时,最大的等待毫秒数,如果超过等待时间,则直接抛JedisConnectionException;
- testOnBorrow:获得一个jedis实例的时候是否检查连接可用性(ping());如果为true,则得到的jedis实例均是可用的;
二 将人群包存放到Redis中
1 思路分析
- 查询出人群包 uids的集合
- 写入redis
- type:set(不需要有序,排除zset;需要单值排除hash;list中很多不是幂等操作,最终选择set)
- key:user_group: 101(user_group:user_group_id)
- value:uid …
- field score:无
- 写api:sadd
- 读api:smembers
- 失效:不是临时值,不设失效
2 代码实现
UserGroupMapper.java
添加方法
/**
* 数组无法存入到List中,mybatis进行封装,想要封装到List中,需要变为一行一行的值
* 数组是一行值,不同uid间以逗号分隔
* 将数组变为很多行,在ClickHouse中可以使用arrayJoin函数将数组炸开
* @param userGroupId
* @return
*/
@Select("select arrayJoin( bitmapToArray(us) ) as us from user_group where user_group_id=#{userGroupId}")
@DS("clickhouse")
public List<String> userGroupUidList(@Param("userGroupId") String userGroupId);
UserGroupServiceImpl.java
添加代码
// 3 人群包(包含所有uid)以应对高QPS访问
// redis(bitmap/set)
/**
* - 查询出人群包 uids的集合
*
* - 写入redis
* - type:set(不需要有序,排除zset;需要单值排除hash;list中很多不是幂等操作,最终选择set)
* - key:user_group: 101(user_group:user_group_id)
* - value:uid ...
* - field score:无
* - 写api:sadd
* - 读api:smembers
* - 失效:不是临时值,不设失效
*/
List<String> uidList = super.baseMapper.userGroupUidList(userGroup.getId().toString());
Jedis jedis = RedisUtil.getJedisFromPool();
String key = "user_group:" + userGroup.getId();
String[] uidArr = uidList.toArray(new String[]{});
jedis.sadd(key,uidArr);
jedis.close();
3 结果验证
在网页创建分群,然后在redis中查看是否存在数据
keys *
smembers user_group:id(会有具体的数字)
三 挖掘类标签
1 挖掘类标签与机器学习
挖掘类标签需要用算法挖掘用户的相关特征,比如:性别预测、年龄预测、 用户流失预测、风险欺诈预测。
相比统计、规则类这些通过专业人员制定明确规则的标签,挖掘类的标签完全是另一套处理思路。
获得挖掘标签过程:
2 模型建立过程
整个挖掘的过程的核心就是建立、完善模型的过程。
一个模型完善的过程是个没有尽头的迭代。
(1)数据预处理
主要是对数据的初步的清洗加工,这个过程一般可以在数仓中完成,然后在数仓中稍微的添加一些操作。
(2)特征工程
主要是特征的选择和提取。比如想预测用户的流失,那就要选择哪些指标字段会和用户的流失有比较强的相关性。要从数仓中,把这些指标提取出来并进一步加工。
除了获得特征,还需要“参考答案”,比如抽选出来的这些用户特征,那这些用户到底是不是流失的,要标记出来,用于机器学习。
特征的选取往往不能一蹴而就,需要反复的迭代尝试。
(3)算法选取
目前机器学习的算法种类繁多,比如分类算法领域中:决策树、随机森林、逻辑回归、GBDT、XGBoost。
回归算法领域中:线性回归、多项式回归、岭回归、Lasso回归、弹性回归。
在画像领域中,主要使用分类算法。但具体使用哪种分类算法,也是需要不断尝试验证的,没有一定的标准。
(4)模型训练
通过代码实现“数据 + 算法 = 模型”,可以使用scala调用sparkMLlib工具包实现机器学习训练,将模型存储在hdfs。
(5)评估优化
一般会把数据进行分组,训练组和验证组,然后对模型组进行准率的评估。
根据准确率,对模型进行优化:
优化一般主要是三个方面:
- 特征选取和提炼
- 算法的比较和选择
- 算法的参数调整
(6)投放使用、验证优化
把模型投放到实际的标签生产中去观察,比如预测流失的用户,一段时间是否真的会流失。
或者进行A/B测试,对预测的一部分用户采取某种措施,另一部分用户不作处理。观察两组人的变化效果。
通过实际生产中的预测效果,不断的反复调整模型、算法。
四 决策树算法
机器学习【决策树算法1】
机器学习【决策树算法2】
使用决策树需要解决的问题:
- 选取什么特征进行判断
- 特征判断的先后顺序
- 连续值如何切分
五 用SparkMLlib实现决策树的使用
1 利用流水线完成模型训练
训练 + 预测的完成过程如下图:
2 数据准备
(1)建表语句
create table student
( uid bigint ,
hair string,
height bigint ,
skirt string,
age string ,
gender string
)
(2)模拟数据
insert overwrite table student
values
( 1,'长发' ,155,'是', '80后','女' ),
( 2,'短发' ,156,'否', '90后','女' ),
( 3,'长发' ,157,'是', '00后','女' ),
( 4,'短发' ,158,'否', '80后','女' ),
( 5,'长发' ,159,'是', '90后','女' ),
( 6,'短发' ,160,'否', '00后','女' ),
( 7,'长发' ,161,'否', '80后','女' ),
( 8,'短发' ,162,'否', '90后','女' ),
( 9,'长发' ,163,'是', '00后','女' ),
( 10,'短发' ,164,'否', '80后','女' ),
( 11,'长发' ,165,'是', '90后','女' ),
( 12,'短发' ,166,'否', '00后','女' ),
( 13,'长发' ,167,'是', '80后','女' ),
( 14,'短发' ,168,'否', '90后','女' ),
( 15,'板寸' ,169,'是', '00后','女' ),
( 16,'短发' ,160,'否', '80后','女' ),
( 17,'长发' ,171,'是', '90后','女' ),
( 18,'短发' ,162,'否', '00后','女' ),
( 19,'长发' ,173,'是', '80后','女' ),
( 20,'短发' ,174,'否', '90后','女' ),
( 21,'长发' ,175,'是', '00后','女' ),
( 22,'短发' ,155,'否', '80后','女' ),
( 23,'长发' ,156,'否', '90后','女' ),
( 24,'短发' ,157,'否', '00后','女' ),
( 25,'长发' ,158,'否', '80后','女' ),
( 26,'短发' ,159,'否', '90后','女' ),
( 27,'长发' ,160,'是', '00后','女' ),
( 28,'短发' ,161,'否', '00后','女' ),
( 29,'长发' ,162,'是', '80后','女' ),
( 30,'短发' ,163,'否', '00后','女' ),
( 31,'长发' ,164,'是', '80后','女' ),
( 32,'短发' ,165,'否', '00后','女' ),
( 33,'长发' ,166,'是', '00后','女' ),
( 34,'短发' ,167,'否', '80后','女' ),
( 35,'长发' ,169,'是', '90后','女' ),
( 36,'短发' ,170,'否', '00后','女' ),
( 37,'长发' ,171,'是', '80后','女' ),
( 38,'短发' ,172,'是', '90后','女' ),
( 39,'长发' ,173,'否', '00后','女' ),
( 40,'长发' ,174,'否', '80后','女' ),
( 41,'短发' ,175,'是', '90后','女' ),
( 42,'短发' ,165,'否', '00后','女' ),
( 43,'短发' ,166,'是', '80后','女' ),
( 44,'长发' ,167,'否', '90后','女' ),
( 45,'短发' ,168,'是', '00后','女' ),
( 46,'短发' ,169,'否', '80后','女' ),
( 47,'长发' ,170,'是', '90后','女' ),
( 48,'短发' ,171,'否', '00后','女' ),
( 49,'长发' ,172,'是', '80后','女' ),
( 50,'短发' ,173,'否', '90后','女' ),
( 51,'短发' ,165,'否', '80后','男' ),
( 52,'板寸' ,166,'否', '90后','男' ),
( 51,'短发' ,167,'否', '00后','男' ),
( 52,'板寸' ,168,'否', '80后','男' ),
( 53,'短发' ,169,'否', '90后','男' ),
( 54,'短发' ,170,'否', '00后','男' ),
( 55,'短发' ,171,'否', '80后','男' ),
( 56,'板寸' ,172,'否', '90后','男' ),
( 57,'短发' ,173,'否', '00后','男' ),
( 58,'短发' ,174,'否', '80后','男' ),
( 59,'短发' ,175,'否', '90后','男' ),
( 60,'短发' ,176,'否', '00后','男' ),
( 61,'短发' ,177,'否', '80后','男' ),
( 62,'短发' ,178,'否', '90后','男' ),
( 63,'短发' ,179,'否', '00后','男' ),
( 64,'板寸' ,180,'否', '80后','男' ),
( 65,'短发' ,181,'否', '90后','男' ),
( 66,'短发' ,182,'否', '80后','男' ),
( 67,'短发' ,183,'否', '80后','男' ),
( 68,'短发' ,184,'否', '90后','男' ),
( 69,'短发' ,185,'否', '80后','男' ),
( 70,'短发' ,166,'否', '80后','男' ),
( 71,'短发' ,167,'否', '90后','男' ),
( 72,'板寸' ,168,'否', '00后','男' ),
( 73,'短发' ,169,'否', '80后','男' ),
( 74,'短发' ,170,'否', '90后','男' ),
( 75,'短发' ,171,'否', '00后','男' ),
( 76,'板寸' ,172,'否', '80后','男' ),
( 77,'短发' ,173,'否', '90后','男' ),
( 78,'短发' ,174,'否', '00后','男' ),
( 79,'短发' ,175,'否', '80后','男' ),
( 80,'板寸' ,176,'否', '90后','男' ),
( 81,'短发' ,177,'否', '00后','男' ),
( 82,'短发' ,178,'否', '80后','男' ),
( 83,'短发' ,179,'否', '90后','男' ),
( 84,'短发' ,180,'否', '80后','男' ),
( 85,'短发' ,181,'否', '80后','男' ),
( 86,'板寸' ,182,'否', '90后','男' ),
( 87,'短发' ,183,'否', '00后','男' ),
( 88,'短发' ,184,'否', '80后','男' ),
( 89,'短发' ,185,'否', '90后','男' ),
( 90,'板寸' ,184,'否', '00后','男' ),
( 91,'短发' ,171,'否', '80后','男' ),
( 92,'短发' ,172,'否', '90后','男' ),
( 93,'短发' ,173,'否', '00后','男' ),
( 94,'短发' ,174,'否', '80后','男' ),
( 95,'短发' ,175,'否', '90后','男' ),
( 96,'板寸' ,176,'否', '00后','男' ),
( 97,'短发' ,177,'否', '80后','男' ),
( 98,'板寸' ,178,'否', '90后','男' ),
( 99,'板寸' ,179,'否', '00后','男' ),
( 100,'长发' ,180,'否', '80后','男' ) ,
( 101,'长发' ,155,'是', '80后','女' ),
( 102,'短发' ,156,'否', '90后','女' ),
( 103,'长发' ,157,'是', '00后','女' ),
( 104,'短发' ,158,'否', '80后','女' ),
( 105,'长发' ,159,'是', '90后','女' ),
( 106,'短发' ,160,'否', '00后','女' ),
( 107,'长发' ,161,'否', '80后','女' ),
( 108,'短发' ,162,'否', '90后','女' ),
( 109,'长发' ,163,'是', '00后','女' ),
( 110,'短发' ,164,'否', '80后','女' )
将数据存放到hive中。
3 模块搭建
在user-profile-task1016下创建task-ml,如下图:
在pom.xml引入依赖
<dependencies>
<dependency>
<groupId>com.hzy.userprofile</groupId>
<artifactId>task-common</artifactId>
<version>1.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.12</artifactId>
<version>3.0.0</version>
<scope>provided</scope>
</dependency>
</dependencies>
<build>
<plugins>
<!-- 该插件用于将Scala代码编译成class文件 -->
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<version>3.4.6</version>
<executions>
<execution>
<!-- 声明绑定到maven的compile阶段 -->
<goals>
<goal>compile</goal>
<goal>testCompile</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<version>3.0.0</version>
<configuration>
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs>
</configuration>
<executions>
<execution>
<id>make-assembly</id>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
4 创建流水线对象
流水线PipeLine实际上就是执行一些预处理工作,其中
-
标签索引:参考答案,数据集中的最后一列,将标签值转换为矢量值,也就是将
男,女
转换为0,1
。按照出现概率的大小次序排序,概率越大,矢量越小。
-
特征聚合:在原始数据中选择特征列,并集合成一列。
-
特征索引:将特征集合中的原值转换为矢量值,转换规则同标签索引。
需要识别哪些是连续值特征,哪些是离散值特征,具体判断标准:底层会设置一个阈值,高于阈值判断为连续值,否则为离散值,即小于等于。
class MyPipeline {
// 5 用于接收此阶段最终的结果
var pipeline:Pipeline = null
//最大分类树(用于识别连续值特征和分类特征),用于3创建特征索引列
private var maxCategories=5
// 最大分支数
private var maxBins=5
// 最大树深度
private var maxDepth=5
//最小分支包含数据条数
private var minInstancesPerNode=1
//最小分支信息增益
private var minInfoGain=0.0
}
5 增加流水线组件(三个徒弟,一个师傅)
(1)创建标签索引
// 用于1 标签索引
var labelColName: String = null
// 从外部注入
def setLabelColName(labelColName: String) : MyPipeline = {
this.labelColName = labelColName
this
}
// 1 创建标签索引
def createLabelIndexer():StringIndexer = {
// 输入的原始数据 结构为DF
val indexer = new StringIndexer()
// 设置输入列和输出列
// 输入列为数据的最后一列,通过外部传递进来
// 输出列与外部数据没有关系,直接固定下来即可
// 最终会在DF中增加一列,名称可以自己设置
indexer.setInputCol(labelColName).setOutputCol("label_index")
indexer
}
(2)创建特征集合
// 用于2 特征集合
var featureColNames:Array[String] = null
// 从外部注入
def setFeatureColNames(featureColNames: Array[String]) : MyPipeline = {
this.featureColNames = featureColNames
this
}
// 2 创建特征集合列
def createFeatureAssemble():VectorAssembler = {
val assembler = new VectorAssembler()
// 可以将多个列设置为特征,也可以称为维度,输出列只有一个
assembler.setInputCols(featureColNames).setOutputCol("feature_assemble")
assembler
}
(3)创建特征向量索引
// 3 创建特征索引列
def createFeatureIndexer():VectorIndexer = {
val indexer = new VectorIndexer()
// 特征集合的输出就是特征索引的输入
// 此外还需要设置阈值,用于判断是线性值还是离散值
indexer.setInputCol("feature_assemble").setOutputCol("feature_index").setMaxCategories(maxCategories)
indexer
}
(4)创建分类器
// 4 创建分类器
def createClassifier():DecisionTreeClassifier ={
val classifier = new DecisionTreeClassifier()
// 设置标签列(1),设置特征列(3),设置预测列(自己起名)
classifier.setLabelCol("label_index").setFeaturesCol("feature_index").setPredictionCol("prediction_col")
classifier
}
6 初始化对象
def init():MyPipeline = {
// StringIndexer、VectorAssembler、VectorIndexer、DecisionTreeClassifier
// 以上四者的父类都是PipelineStage,可以理解为是流水线上的一个环节
// 以上前三者都是这个环节中的工人,最后一个是这三个人的师傅
// 执行此方法,师徒四人就要上岗干活了!
pipeline = new Pipeline().setStages( Array(
createLabelIndexer,
createFeatureAssemble,
createFeatureIndexer,
createClassifier
))
this
}
7 训练和预测
// 6 训练,得到模型
def train(dataFrame:DataFrame):Unit ={
pipelineModel = pipeline.fit(dataFrame)
}
// 7 预测
def predict(dataFrame: DataFrame):DataFrame ={
val predictedDataFrame1: DataFrame = pipelineModel.transform(dataFrame)
predictedDataFrame1
}
8 调用工具类,进行测试
新建类StudentGenderTrain,添加配置文件,如下图
源码如下:
package com.hzy.userprofile.ml.train
import com.hzy.userprofile.ml.pipeline.MyPipeline
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
object StudentGenderTrain {
def main(args: Array[String]): Unit = {
val sparkConf: SparkConf = new SparkConf().setAppName("student_gender_train.app").setMaster("local[*]")
val sparkSession: SparkSession = SparkSession.builder().config(sparkConf).enableHiveSupport().getOrCreate()
println("查询数据")
// 1 查询数据
val sql =
s"""
| select
| uid,
| case hair when '长发' then 101 when '短发' then 102 when '板寸' then 103 end as hair,
| height,
| case skirt when '是' then 111 when '否' then 222 end as skirt,
| case age when '00后' then 100 when '90后' then 90 when '80后' then 80 end as age,
| gender
| from
| default.student
|""".stripMargin
println(sql)
val dataFrame: DataFrame = sparkSession.sql(sql)
println("切分数据")
// 2 切分数据:训练集和测试集(82 或 73)
val Array(trainDF,testDF) = dataFrame.randomSplit(Array(0.8,0.2))
println("创建myPipeine")
// 3 创建myPipeine
val myPipeline: MyPipeline = new MyPipeline()
.setLabelColName("gender")
.setFeatureColNames(Array("hair","height","skirt","age"))
.init()
println("进行训练")
// 4 进行训练
myPipeline.train(trainDF)
println("进行预测")
// 5 进行预测
val predictedDataFrame: DataFrame = myPipeline.predict(testDF)
println("打印预测结果")
// 6 打印预测结果
predictedDataFrame.show(100,false)
}
}
运行之前需要配置hadoop用户名,集体结果分析如下:
- 前六列为原始值列,将文字转化为数字。
- lable_index:标签矢量值,男为1女为0(女生数量多,矢量值小),标准答案。
- feature_assemble:特征集合列,将所有特征整合成一列。
- feature_index:将特征集合转化为矢量的集合,其中连续值不会进行转化。
- rawPrediction:机器任务男和女的权重分别是的多少,前面为0号矢量的权重。
- probability:根据权重预测结果。
- prediction_col:最终预测结果。
六 完整代码
1 MyPipeline
package com.hzy.userprofile.ml.pipeline
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler, VectorIndexer}
import org.apache.spark.sql.DataFrame
class MyPipeline {
// 5 用于接收此阶段最终的结果
var pipeline:Pipeline = null
def init():MyPipeline = {
// StringIndexer、VectorAssembler、VectorIndexer、DecisionTreeClassifier
// 以上四者的父类都是PipelineStage,可以理解为是流水线上的一个环节
// 以上前三者都是这个环节中的工人,最后一个是这三个人的师傅
// 执行此方法,师徒四人就要上岗干活了!
pipeline = new Pipeline().setStages( Array(
createLabelIndexer(),
createFeatureAssemble(),
createFeatureIndexer(),
createClassifier()
))
this
}
// 模型:通过训练得来
var pipelineModel:PipelineModel = null
//最大分类树(用于识别连续值特征和分类特征),用于3创建特征索引列
private var maxCategories=5
// 最大分支数
private var maxBins=5
// 最大树深度
private var maxDepth=5
//最小分支包含数据条数
private var minInstancesPerNode=1
//最小分支信息增益
private var minInfoGain=0.0
// 用于1 标签索引
var labelColName: String = null
// 用于2 特征集合
var featureColNames:Array[String] = null
// 从外部注入
def setLabelColName(labelColName: String) : MyPipeline = {
this.labelColName = labelColName
this
}
// 从外部注入
def setFeatureColNames(featureColNames: Array[String]) : MyPipeline = {
this.featureColNames = featureColNames
this
}
// 1 创建标签索引
def createLabelIndexer():StringIndexer = {
// 输入的原始数据 结构为DF
val indexer = new StringIndexer()
// 设置输入列和输出列
// 输入列为数据的最后一列,通过外部传递进来
// 输出列与外部数据没有关系,直接固定下来即可
// 最终会在DF中增加一列,名称可以自己设置
indexer.setInputCol(labelColName).setOutputCol("label_index")
indexer
}
// 2 创建特征集合列
def createFeatureAssemble():VectorAssembler = {
val assembler = new VectorAssembler()
// 可以将多个列设置为特征,也可以称为维度,输出列只有一个
assembler.setInputCols(featureColNames).setOutputCol("feature_assemble")
assembler
}
// 3 创建特征索引列
def createFeatureIndexer():VectorIndexer = {
val indexer = new VectorIndexer()
// 特征集合的输出就是特征索引的输入
// 此外还需要设置阈值,用于判断是线性值还是离散值
indexer.setInputCol("feature_assemble").setOutputCol("feature_index").setMaxCategories(maxCategories)
indexer
}
// 4 创建分类器
def createClassifier():DecisionTreeClassifier ={
val classifier = new DecisionTreeClassifier()
// 设置标签列(1),设置特征列(3),设置预测列(自己起名)
classifier.setLabelCol("label_index").setFeaturesCol("feature_index").setPredictionCol("prediction_col")
classifier
}
// 6 训练,得到模型
def train(dataFrame:DataFrame):Unit ={
pipelineModel = pipeline.fit(dataFrame)
}
// 7 预测
def predict(dataFrame: DataFrame):DataFrame ={
val predictedDataFrame1: DataFrame = pipelineModel.transform(dataFrame)
predictedDataFrame1
}
}
2 StudentGenderTrain
package com.hzy.userprofile.ml.train
import com.hzy.userprofile.ml.pipeline.MyPipeline
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
object StudentGenderTrain {
def main(args: Array[String]): Unit = {
val sparkConf: SparkConf = new SparkConf().setAppName("student_gender_train.app").setMaster("local[*]")
val sparkSession: SparkSession = SparkSession.builder().config(sparkConf).enableHiveSupport().getOrCreate()
println("查询数据")
// 1 查询数据
val sql =
s"""
| select
| uid,
| case hair when '长发' then 101 when '短发' then 102 when '板寸' then 103 end as hair,
| height,
| case skirt when '是' then 111 when '否' then 222 end as skirt,
| case age when '00后' then 100 when '90后' then 90 when '80后' then 80 end as age,
| gender
| from
| default.student
|""".stripMargin
println(sql)
val dataFrame: DataFrame = sparkSession.sql(sql)
println("切分数据")
// 2 切分数据:训练集和测试集(82 或 73)
val Array(trainDF,testDF) = dataFrame.randomSplit(Array(0.8,0.2))
println("创建myPipeine")
// 3 创建myPipeine
val myPipeline: MyPipeline = new MyPipeline()
.setLabelColName("gender")
.setFeatureColNames(Array("hair","height","skirt","age"))
.init()
println("进行训练")
// 4 进行训练
myPipeline.train(trainDF)
println("进行预测")
// 5 进行预测
val predictedDataFrame: DataFrame = myPipeline.predict(testDF)
println("打印预测结果")
// 6 打印预测结果
predictedDataFrame.show(100,false)
}
}