基于Scala开发Spark ML的ALS推荐模型实战

news2024/11/23 23:41:40

推荐系统,广泛应用到电商,营销行业。本文通过Scala,开发Spark ML的ALS算法训练推荐模型,用于电影评分预测推荐。

算法简介

ALS算法是Spark ML中实现协同过滤的矩阵分解方法。

ALS,即交替最小二乘法(Alternating Least Squares),是协同过滤技术中的一种经典算法。它通过对用户和物品的潜在特征进行建模,来预测用户对未知物品的评分或偏好。具体介绍如下:

  1. 矩阵分解模型:在推荐系统中,我们通常有一个用户-物品的评分矩阵,其中行表示用户,列表示物品,矩阵中的值代表用户对物品的评分。然而,这个矩阵通常是非常稀疏的,因为用户只给少数物品评分。ALS算法就是在这样的不完整评分矩阵上操作,通过矩阵分解来补全缺失值,进而产生推荐。
  2. 算法原理:ALS算法的核心思想是通过迭代过程更新用户和物品的潜在因子向量。在每次迭代中,一个评分被建模为用户潜在特征向量和物品潜在特征向量的点积,加上一个偏差项。通过最小化实际评分和预测评分之间的差异来不断优化这些潜在特征向量。
  3. Spark ML实现:在Spark ML库中,ALS算法被用于处理大规模的数据集,并提供了多种参数以适应不同的数据特性和需求。例如,可以设置潜在因子的数量、正则化参数、迭代次数等。此外,Spark ML的ALS还支持隐式反馈数据的变体,这对于无法获取明确评分的数据非常有用。

总的来说,ALS是一种强大的推荐系统算法,尤其适用于处理大规模稀疏数据集。通过合理地选择和调整参数,可以在保持高效计算的同时获得良好的推荐质量。

代码实战

pom.xml文件更新,加入相关依赖

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>sparkGNU2023</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <scala.version>2.13</scala.version>
        <spark.version>3.4.1</spark.version>
        <log4j.version>1.2.17</log4j.version>
        <slf4j.version>1.7.22</slf4j.version>
    </properties>


    <dependencies>

        <!--日志相关依赖-->
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>jcl-over-slf4j</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-log4j12</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <dependency>
            <groupId>log4j</groupId>
            <artifactId>log4j</artifactId>
            <version>${log4j.version}</version>
        </dependency>


        <dependency>
            <groupId>com.thoughtworks.paranamer</groupId>
            <artifactId>paranamer</artifactId>
            <version>2.8</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.13</artifactId>
            <version>3.4.1</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.13</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming_2.13</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-hive_2.13</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming-kafka-0-10_2.13</artifactId>
            <version>3.4.1</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.13</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming-kafka-0-8_2.11</artifactId>
            <version>2.4.8</version>
        </dependency>

        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>8.0.30</version>
        </dependency>

        <dependency>
            <groupId>org.apache.flume.flume-ng-clients</groupId>
            <artifactId>flume-ng-log4jappender</artifactId>
            <version>1.11.0</version>
        </dependency>

        <!--        flume 拦截器相关依赖-->
        <dependency>
            <groupId>org.apache.flume</groupId>
            <artifactId>flume-ng-core</artifactId>
            <version>1.9.0</version>
            <scope>provided</scope>
        </dependency>

        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>1.2.62</version>
        </dependency>

    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.8.1</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-assembly-plugin</artifactId>
                <version>3.6.0</version>
                <configuration>
                    <descriptorRefs>
                        <descriptorRef>jar-with-dependencies</descriptorRef>
                    </descriptorRefs>
                </configuration>
                <executions>
                    <execution>
                        <id>make-assembly</id>
                        <phase>package</phase>
                        <goals>
                            <goal>single</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
    
</project>

训练ALS模型

基于scala训练ALS模型

package base.charpter10

import breeze.linalg.sum
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.functions.{col, count, explode, when}
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * @projectName sparkGNU2023  
 * @package base.charpter10  
 * @className base.charpter10.MovieRecommender  
 * @description ${description}  
 * @author pblh123
 * @date 2024/3/29 15:18
 * @version 1.0
 *
 */
    
object MovieRecommender {

  def main(args: Array[String]): Unit = {
    // 创建Spark会话
    val spark = SparkSession.builder()
      .appName("MovieRecommender")
      .master("local[*]")
      .getOrCreate()

    import spark.implicits._

    // 假设我们有一个用户-物品评分数据集,格式为(userId, itemId, rating)
    /**
     * UserID,MovieID,Rating,Timestamp
     *  1,1193,5,978300760
     *  1,661,3,978302109
     */
    // 指定CSV文件的路径,以及解析选项
    val csvFilePath = "data/ratings.csv"
    val csvOptions = Map(
      "header" -> "true", // 是否有列名头
      "inferSchema" -> "true", // 是否自动推断数据类型
    "encoding" -> "UTF-8", // 如果有特定的编码格式,例如对于包含中文的CSV文件:
    )

    // 读取CSV文件并创建DataFrame
    val ratingsDF = spark.read.format("csv")
      .options(csvOptions)
      .load(csvFilePath)

    // 显示DataFrame的前几行以验证数据是否正确加载
    println("查看原始据数据样例:")
    ratingsDF.show(5)

    val ratings: DataFrame = ratingsDF.select("UserID", "MovieID", "Rating")
      .withColumnRenamed("UserID", "userId")
      .withColumnRenamed("MovieID", "itemId")
      .withColumnRenamed("Rating", "rating")


    // 将数据集分割为训练集和测试集
    val Array(training, test) = ratings.randomSplit(Array(0.8, 0.2))

    println("查看训练集数据")
    training.show(5)
    println("查看测试集数据")
    test.show(5)


    // 设置ALS参数
    // 创建一个ALS实例并配置参数
    val als = new ALS()
      .setMaxIter(10) // 设置最大迭代次数为5,10,本地测试时,设置过大,会报错
      .setRegParam(0.01) // 设置正则化参数为0.01
      .setUserCol("userId") // 设置用户列名为"userId"
      .setItemCol("itemId") // 设置物品列名为"itemId"
      .setRatingCol("rating") // 设置评分列名为"rating"

    /**
     * ALS(Alternating Least Squares)是一种基于矩阵分解的协同过滤算法,用于处理用户和物品之间的评分数据。各参数说明如下:
     *  setMaxIter: 设置最大迭代次数,决定模型训练的精细程度。迭代次数越多,模型通常越精确,但训练时间也可能更长。
     *  setRegParam: 设置正则化参数,用于控制模型的复杂度和过拟合程度。较小的正则化参数值可能导致模型过复杂,容易过拟合;较大的值则可能导致模型过于简单,欠拟合。
     *  setUserCol, setItemCol, setRatingCol: 分别设置用户ID列、物品ID列和评分列的名称。这些列名根据实际的数据结构来确定,用于告诉ALS算法在哪些列中查找用户、物品和评分信息。
     */


    // 训练ALS模型
    println("开始训练模型")
    val model = als.fit(training)

    // 对测试集进行预测
    val predictions = model.transform(test)

    predictions.show()
    predictions.filter($"rating".isNotNull && $"prediction".isNotNull).count() // 确认有非空的评分和预测值


    // 评估模型
    val evaluator = new RegressionEvaluator()
      .setMetricName("rmse")
      .setLabelCol("rating")
      .setPredictionCol("prediction")
    val rmse = evaluator.evaluate(predictions)
    println(s"Root-mean-square error = $rmse")

    // 为用户生成推荐
    // 该函数是基于一个模型(model)为所有用户推荐项目的函数。它将为每个用户推荐5个项目
    /**
     * +------+--------------------------------------------------------------------------------------------+
     *  |userId|recommendations[{itemid,pred_rating},{itemid,pred_rating},...]                                                                             |
     *  +------+--------------------------------------------------------------------------------------------+
     *  |12    |[{1864, 9.721167}, {2964, 8.815781}, {3867, 8.480173}, {1539, 7.8904114}, {563, 7.8829007}] |
     *  |22    |[{2964, 6.090676}, {3215, 5.6165895}, {1534, 5.4731245}, {718, 5.462125}, {2632, 5.4482727}]|
     */

    val userRecs = model.recommendForAllUsers(5)
    userRecs.show(5,false)
    println("保存预测结果")
//    userRecs.write.mode("overwrite").parquet("models/recomALSmodel") // 保存为parquet格式,一般用于集群中
// userRecs是一个DataFrame,其中"recommendations"列是数组类型
    val explodedUserRecs = userRecs.withColumn("recommendations", explode($"recommendations"))
      .select($"userId", $"recommendations.itemId".as("itemId"), $"recommendations.rating".as("PredRating"))
    explodedUserRecs.write.mode("overwrite").format("csv").save("predictRes/recomALS")  // PC 调试使用

    // 保存模型到指定路径
    val modelPath = "models/recomALSmodel"
    model.write.overwrite().save(modelPath)
    println(s"Model saved to $modelPath")

    // 停止Spark会话
    spark.stop()

    /*
    当程序试图停止Spark会话时,可能会触发清理临时文件的操作,
    从而导致出现NoSuchFileException异常。通常情况下,这不是代码逻辑的问题,
    而是Spark内部在清理资源时可能出现的问题。
    可以尝试重启Spark环境或者适当增大Spark的临时目录空间来避免此类问题。
     */

  }

}

运行代码,效果图如下

TodoList:目前RMSE计算出问题,原数据清洗没有做,模型参数还可以调整。后期调整更新后,再发一篇文章。

使用训练的模型预测新数据

scala开发应用模型demo代码

package base.charpter10

import org.apache.spark.ml.recommendation.ALSModel
import org.apache.spark.sql.SparkSession

/**
 * @projectName sparkGNU2023  
 * @package base.charpter10  
 * @className base.charpter10.RecommendationModelLoadDemo  
 * @description ${description}  
 * @author pblh123
 * @date 2024/3/29 15:36
 * @version 1.0
 *
 */
    
object RecommendationModelLoadDemo {
  def main(args: Array[String]): Unit = {
    // 创建Spark会话
    val spark = SparkSession.builder().master("local[*]")
      .appName("RecommendationModelUsageDemo")
      .getOrCreate()

    import spark.implicits._

    // 加载之前保存的ALS模型
    val modelPath = "models/recomALSmodel"
    val loadedModel: ALSModel = ALSModel.load(modelPath)

    // 假设我们有一些新的用户-物品对,我们想要预测它们的评分
    val userItemPairs = Seq(
      (1, 4), // 用户1对物品4的评分预测
      (2, 2) // 用户2对物品2的评分预测
    ).toDF("userId", "itemId")

    // 使用模型进行评分预测
    val predictions = loadedModel.transform(userItemPairs)
    predictions.show()

    // 现在,假设我们想要为用户1生成前N个推荐物品
    val numRecommendations = 5 // 为用户推荐的物品数量
    val userRecs = loadedModel.recommendForAllUsers(numRecommendations)
    userRecs.show(5,false)

    // 停止Spark会话
    spark.stop()
  }

}

运行效果如下

评估效果说明:目前的预测评分不合理,是因为模型没有经过精挑,优化,预测的记过会依据预测评分高低排序,选取得分高的前5个结果返回。后期模型调优后,结果就正常了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1564816.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

练习 16 Web [极客大挑战 2019]LoveSQL

extractvalue(1,concat(‘~’, (‘your sql’) ) )报错注入&#xff0c;注意爆破字段的时候表名有可能是table_name不是table_schema 有登录输入框 常规尝试一下 常规的万能密码&#xff0c;返回了一个“admin的密码”&#xff1a; Hello admin&#xff01; Your password is…

Java获取IP地址以及MAC地址(附Demo)

目录 前言1. IP及MAC2. 特定适配器 前言 需要获取客户端的IP地址以及MAC地址 import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader;public class test {public static void main(String[] args) {try {// 执行命令Process process…

基于springboot实现房屋租赁管理系统项目【项目源码+论文说明】计算机毕业设计

基于springboot实现房屋租赁系统演示 摘要 房屋是人类生活栖息的重要场所&#xff0c;随着城市中的流动人口的增多&#xff0c;人们对房屋租赁需求越来越高&#xff0c;为满足用户查询房屋、预约看房、房屋租赁的需求&#xff0c;特开发了本基于Spring Boot的房屋租赁系统。 …

蓝桥杯第八届c++大学B组详解

目录 1.购物单 2.等差素数列 3.承压计算 4.方格分割 5.日期问题 6.包子凑数 7.全球变暖 8.k倍区间 1.购物单 题目解析&#xff1a;就是将折扣字符串转化为数字&#xff0c;进行相加求和。 #include<iostream> #include<string> #include<cmath> usin…

【Python项目】基于django的【医用耗材网上申领系统】

医院信息化是社会发展的一个重要标志&#xff0c;它涉及到医院的各个方面&#xff0c;包括人员和物资&#xff0c;因此受到社会各界的广泛关注。近年来&#xff0c;随着医疗耗材数量的不断增加&#xff0c;如何有效管理这些耗材已经成为管理人员、医生以及社会各方面共同面临的…

【Web】记录Polar靶场<困难>难度题一遍过

目录 上传 PHP是世界上最好的语言 非常好绕的命令执行 这又是一个上传 网站被黑 flask_pin veryphp 毒鸡汤 upload tutu Unserialize_Escape 自由的文件上传系统​​​​​​​ ezjava 苦海 你想逃也逃不掉 safe_include CB链 phar PHP_Deserializatio…

Web CSS笔记3

一、边框弧度 使用它你就可以制作盒子边框圆角 border-radius&#xff1a;1个值四个圆角值相同2个值 第一个值为左上角与右下角&#xff0c;第二个值为右上角与左下角3个值第一个值为左上角, 第二个值为右上角和左下角&#xff0c;第三个值为右下角4个值 左上角&#xff0c;右…

Mac下Docker Desktop starting的解决方法

记录下自己在新增了一个新的容器后&#xff0c;Disk Size过大导致启动Docker Desktop会一直卡在Docker Desktop starting&#xff0c;并且重启无效的解决方法。该方法无需重新卸载&#xff0c;并且能保留原有的镜像和容器。 一、确认问题 首先确认Docker.raw大小以确认是否和笔…

异地文件如何共享访问?

异地文件共享访问是一种让不同地区的用户能够快速、安全地共享文件的解决方案。人们越来越需要在不同地点之间共享文件和数据。由于复杂的网络环境和安全性的问题&#xff0c;实现异地文件共享一直是一个挑战。 为了解决这个问题&#xff0c;许多公司和组织研发了各种异地文件共…

xilinx fpga程序固化

一、前言 xilinx 旗下的产品主要有包含有处理器的SOC系列&#xff0c;也有只有纯逻辑的fpga&#xff0c;两者的程序固化的方法并不相同&#xff0c;本文介绍只包含纯逻辑而不涉及处理器的fpga的代码固化。 二、固化流程 将工程综合&#xff0c;实现&#xff0c;并得到比特流…

【蓝桥杯第十三届省赛B组】(部分详解)

九进制转十进制 #include <iostream> #include<math.h> using namespace std; int main() {cout << 2*pow(9,3)0*pow(9,2)2*pow(9,1)2*pow(9,0) << endl;return 0; }顺子日期 #include <iostream> using namespace std; int main() {// 请在此…

数据如何才能供得出、流得动、用得好、还安全

众所周知&#xff0c;数据要素已经列入基本生产要素&#xff0c;同时成立国家数据局进行工作统筹。目前数据要素如何发挥其价值&#xff0c;全国掀起了一浪一浪的热潮。 随着国外大语言模型的袭来&#xff0c;国内在大语言模型领域的应用也大放异彩&#xff0c;与此同时&#x…

更高效、更简洁的 SQL 语句编写丨DolphinDB 基于宏变量的元编程模式详解

元编程&#xff08;Metaprogramming&#xff09;指在程序运行时操作或者创建程序的一种编程技术&#xff0c;简而言之就是使用代码编写代码。通过元编程将原本静态的代码通过动态的脚本生成&#xff0c;使程序员可以创建更加灵活的代码以提升编程效率。 在 DolphinDB 中&#…

Android14之BpBinder构造函数Handle拆解(二百零四)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a;多媒…

python文件处理:解析docx/word文件文字、图片、复选框

前言 因为一些项目原因&#xff0c;我需要提供解析docx内容功能。本来以为这是一件比较简单的工作&#xff0c;没想到在解析复选框选项上吃了亏&#xff0c;并且较长一段时间内通过各种渠道都没有真正解决这一问题&#xff0c;反而绕了远路。 终于&#xff0c;我在github pytho…

Kafka架构概述

Kafka的体系结构 Kafka是由Apache软件基金会管理的一个开源的分布式数据流处理平台。Kafka具有支持消息的发布/订阅模式、高吞吐量与低延迟、持久化、支持水平扩展、高可用性等特点。可以将Kafka应用于大数据实时处理、高性能数据管道、流分析、数据集成和关键任务应用等场景。…

测开——基础理论面试题整理

1. 测试流程 需求了解分析需求评审制定测试计划【包括测试人员、时间、每人负责的模块、测试的风险项以及预防】编写自动化测试用例 —— 测试评审【尽量丰富测试点】编写测试框架和脚本&#xff08;若是功能测试 可省去这步骤&#xff09;执行测试提交缺陷报告测试分析与评审…

BIM转Power BI数据集

在本博客中&#xff0c;我们将了解如何使用从 SSAS 表格、Power BI Desktop 或 Power BI 服务数据集中提取的 Model.bim 文件在本地或 PBI 服务上生成新数据集。 1、设置&#xff08;SSAS 表格和 PBI 服务通用&#xff09; 我建议你创建一个专门用于此任务的新 Python 环境&a…

docker部署DOS游戏

下载镜像 docker pull registry.cn-beijing.aliyuncs.com/wuxingge123/dosgame-web-docker:latestdocker-compose部署 vim docker-compose.yml version: 3 services:dosgame:container_name: dosgameimage: registry.cn-beijing.aliyuncs.com/wuxingge123/dosgame-web-docke…

微软云学习环境

微软公有云 - Microsoft Azure 本文介绍通过微软学习中心Microsoft Learn来免费试用Azure上的服务&#xff0c;也不需要绑定信用卡。不过每天只有几个小时的时间。 官网 https://docs.microsoft.com/zh-cn/learn/ 实践 比如创建虚拟机&#xff0c;看到自己的账号下多了Learn的…