1、broadcast广播
在Spark中,broadcast是一种优化技术,它可以将一个只读变量缓存到每个节点上,以便在执行任务时使用。这样可以避免在每个任务中重复传输数据。
2、构建缓存
import org.apache.spark.sql.SparkSession
import org.apache.spark.broadcast.Broadcast
import com.alibaba.fastjson.JSONObject
// 定义全局缓存单例对象
object GlobalCache extends Serializable {
// 广播变量,用于存储缓存数据
private var cacheData: Broadcast[collection.mutable.Map[String, JSONObject]] = _
// 设置 SparkSession 和广播变量
def setSparkSession(spark: SparkSession): Unit = {
cacheData = spark.sparkContext.broadcast(collection.mutable.Map.empty[String, JSONObject])
}
// 按订单ID和用户ID缓存JSONObject对象
def cacheJSONObject(orderId: String, userId: String, jsonObject: JSONObject): Unit = {
// 获取广播变量的值并进行修改
val data = cacheData.value
data.synchronized {
data.put(generateKey(orderId, userId), jsonObject)
}
}
// 根据订单ID和用户ID删除缓存的JSONObject对象
def removeJSONObject(orderId: String, userId: String): Unit = {
// 获取广播变量的值并进行修改
val data = cacheData.value
data.synchronized {
data.remove(generateKey(orderId, userId))
}
}
// 根据订单ID和用户ID获取缓存的JSONObject对象
def getJSONObjet(orderId: String, userId: String): JSONObject = {
// 获取广播变量的值并进行访问
val data = cacheData.value
data.synchronized {
data.get(generateKey(orderId, userId)).orNull
}
}
// 生成缓存键,使用订单ID和用户ID拼接
private def generateKey(orderId: String, userId: String): String = s"$orderId|$userId"
}
3、缓存测试
import org.apache.spark.sql.SparkSession
import org.apache.spark.broadcast.Broadcast
import com.alibaba.fastjson.JSONObject
import org.apache.log4j.{Level, Logger}
object CacheTest {
Logger.getLogger("org").setLevel(Level.ERROR)
Logger.getRootLogger().setLevel(Level.ERROR) // 设置日志级别
def addItem(orderId:String, userId:String, name:String): Unit = {
val jsonObject = new JSONObject()
jsonObject.put("name", name)
// 缓存JSONObject对象
GlobalCache.cacheJSONObject(orderId, userId, jsonObject)
}
def getCache(orderId: String, userId: String): JSONObject = {
// 获取缓存的JSONObject对象
GlobalCache.getJSONObjet(orderId, userId)
}
def delItem(orderId:String, userId:String): Unit = {
// 删除缓存的JSONObject对象
GlobalCache.removeJSONObject(orderId, userId)
}
def getSparkSession(appName: String, localType: Int): SparkSession = {
val builder: SparkSession.Builder = SparkSession.builder().appName(appName)
if (localType == 1) {
builder.master("local[8]") // 本地模式,启用8个核心
}
val spark = builder.getOrCreate() // 获取或创建一个新的SparkSession
spark.sparkContext.setLogLevel("ERROR") // Spark设置日志级别
spark
}
def main(args: Array[String]): Unit = {
println("Start CacheTest")
val spark: SparkSession = getSparkSession("CacheTest", 1)
GlobalCache.setSparkSession(spark) // 构造全局缓存
addItem("001", "456", "苹果") // 添加元素
addItem("002", "789", "香蕉") // 添加元素
var cachedObject = getCache("001", "456")
println(s"Cached Object: $cachedObject")
delItem("001", "456") // 删除元素
cachedObject = getCache("001", "456")
println(s"Cached Object: $cachedObject")
spark.stop()
}
}
4、控制台输出
Start CacheTest
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Cached Object: {"name":"苹果"}
Cached Object: null
Process finished with exit code 0