推荐算法中经典排序算法GBDT+LR

news2025/2/26 22:57:02

文章目录

    • 逻辑回归模型
      • 逻辑回归对于特征处理的优势
      • 逻辑回归处理特征的步骤
    • GBDT算法
    • GBDT+LR算法
      • GBDT + LR简单代码实现

逻辑回归模型

  • 逻辑回归(LR,Logistic Regression)是一种传统机器学习分类模型,也是一种比较重要的非线性回归模型,其本质上是在线性回归模型的基础上,加了一个Sigmoid函数(也就是非线性映射),由于其简单、高效、易于并行计算的特点,在工业界受到了广泛的应用。

  • 使用LR模型主要是用于分类任务,通常情况下也都是二分类任务,一般在推荐系统的业务中,会使用LR作为Baseline模型快速上线。

  • 从本质上来讲,逻辑回归和线性回归一样同属于广义线性模型。虽然说逻辑回归可以实现回归预测,但是在推荐算法中,我们都将其看作是线性模型并把它应用在分类任务中。

  • 总结:逻辑回归实际上就是在数据服从伯努利分布的假设下,通过极大似然的方法,运用梯度下降算法对参数进行求解,从而达到二分类。

  • Q&A

    • Q:为什么在推荐系统中使用逻辑回归而不是线性回归?
    • A:原因主要有三点
      1. 虽然两者都是处理二分类问题,但是线性回归主要是处理连续性特征,例如一个人的身高体重,而推荐系统中对于一个物品推荐给用户,物品的特征是离散的。因此不好用线性回归作为预测
      2. 逻辑回归模型具有较高的可解释性,可以帮助我们更好地理解推荐系统中不同因素对推荐结果的影响,而线性回归在这一点上无法实现我们的需求。
      3. 推荐系统中往往存在很多的噪音,逻辑回归可以更好地处理异常值,避免推荐结果被干扰。

逻辑回归对于特征处理的优势

  1. 特征选择:可以使用正则化技术来选择重要的特征,提高模型效率和准确性。
  2. 处理非线性特征:可以通过引入多项式和交互特征来处理非线性特征,从而预测结果更加准确。
  3. 处理缺失值和异常值:可以通过处理缺失值和异常值使得模型更佳健壮,从而预测结果更加准确。
  4. 训练模型速度:训练速度相对较快。

逻辑回归处理特征的步骤

  1. 特征选择:逻辑回归可以使用正则化技术(L1、L2正则化)来选择最重要的特征,从而降低维度并去除无关特征。选择相关的特征有助于提高模型的稳定性和准确性。
  2. 处理缺失值和异常值:逻辑回归可以使用缺失值插补和异常值检测来处理缺失值和异常值,从而避免对模型产生影响,提高模型的健壮性。
  3. 处理非线性特征:逻辑回归可以通过引入多项式和交互特征来处理非线性特征,从而增强模型的表现力。以多项式模型为例,逻辑回归可以使用幂函数或指数函数对特征进行转换,从而处理非线性变量。
  4. 特征标准化:逻辑回归可以使用特征标准化来消除特征数据值的量纲影响,避免数值范围大的特征对模型产生很大影响。
  5. 特征工程:逻辑回归也可以使用特征工程来创建新的特征,例如聚合或拆分现有特征、提取信号等。这有助于发现与目标变量相关的新信息,从而改进对数据的理解。

GBDT算法

  • GBDT(Gradient Boosting Decision Tree)算法是一种基于决策树的集成学习算法,它的学习方式是梯度提升,它通过不断训练决策树来提高模型的准确性。GBDT在每一次训练中都利用当前的模型进行预测,并将预测误差作为新的样本权重,然后训练下一棵决策树模型来拟合加权后的新数据。
  • GBDT中的B代表的是Boosting算法,Boosting算法的基本思想是通过将多个弱分类器线性组合形成一个强分类器,达到优化训练误差和测试误差的目的。具体应用时,每一轮将上一轮分类错误的样本重新赋予更高的权重,这样一来,下一轮学习就容易重点关注错分样本,提高被错分样本的分类准确率。GBDT由多棵CART树组成,本质是多颗回归树组成的森林。每一个节点按贪心分裂,最终生成的树包含多层,这就相当于一个特征组合的过程。
  • 理论上来说如果可以无限制的生成决策树,GBDT就可以无限逼近由所有训练样本所组成的目标拟合函数,从而达到减小误差的目的,同样这种思想后来也被运用在了ResNet残差神经网络上。
  • 在推荐系统中,我们使用GBDT算法来优化和提高个性化推荐的准确性。通过GBDT算法对用户历史行为数据进行建模和学习,可以很容易地学习到用户的隐式特征(例如品味,偏好,消费能力等)。另外,GBDT算法可以自动选择重要的特征,对离散型和连续型特征进行处理(如缺失值填充、离散化等),为特征工程提供更好的支持。

GBDT+LR算法

  • 在推荐系统中使用GBDT+LR结合算法主要是用于处理点击率预估,这个也是facebook在2014年发表的论文Practical Lessons from Predicting Clicks on Ads at Facebook。根据点击率预估的结果进行排序,所以GBDT+LR用于排序层。
  • GBDT+LR架构图
    在这里插入图片描述
  • 整个模型主要分为两个步骤,上面的GBDT和下面的LR。主要分为五个步骤:
    1. GBDT训练:使用GBDT对原始数据进行训练并生成特征。在训练过程中,每棵树都是基于前一棵树的残差进行构建。这样,GBDT可以逐步减少残差,生成最终的目标值。
    2. 特征转换:使用GBDT生成的特征进行转换。这些特征是树节点的输出,每个特征都对应于一个叶子节点。在转换过程中,每个叶子节点都会被转换为一个新的特征向量,代表这个叶子节点与其他节点的相对位置,并将这些特征向量连接起来形成新的训练集(用于下一步LR)。
    3. 特征归一化:对生成的特征进行归一化处理,确保不同维度的特征在训练过程中具有相等的权重。
    4. LR训练:使用LR对转换后的特征进行二分类或回归。特征向量被送入LR模型中进行训练,以获得最终的分类模型。在训练过程中,使用梯度下降法来更新模型参数,以最小化损失函数,损失函数的选择取决于分类问题的具体情况。
    5. 模型预测:训练完成后,使用LR模型对新的数据进行预测。GBDT+LR模型将根据特征生成函数和逻辑回归模型预测新数据的类别或值。
  • GBDT + LR优缺点
    • 优点:
      提高模型预测精度
      模型具有鲁棒性和可扩展性
      具有良好的可解释性
    • 缺点:建模复杂度高
      训练时间计算成本高
      对异常值和噪声数据敏感

GBDT + LR简单代码实现

import pandas as pd
import numpy as np
import lightgbm as lgb
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
from sklearn.preprocessing import OneHotEncoder

# 读取数据集
# 数据集地址https://grouplens.org/datasets/movielens/
ratings = pd.read_csv("../../data/ml-1m/ratings.dat", sep="::", header=None, names=["user_id", "movie_id", "rating", "timestamp"], encoding='ISO-8859-1', engine="python")
movies = pd.read_csv("../../data/ml-1m/movies.dat", sep="::", header=None, names=["movie_id", "title", "genres"], encoding='ISO-8859-1', engine="python")
# 将两个数据集根据movie_id合并,并去掉timestamp,title
data = pd.merge(ratings, movies, on="movie_id").drop(columns=["timestamp", "title"])
# print(data.head(50))

# genres字段转换为多个二值型变量(使用pandas的get_dummies函数)
genres_df = data.genres.str.get_dummies(sep="|")
# genres_df = pd.get_dummies(data['genres'])
# print(genres_df.head(50))
data = pd.concat([data, genres_df], axis=1).drop(columns=["genres"])
# print(data.head(50))

# 提取出用于训练 GBDT 模型和 LR 模型的特征和标签
features = data.drop(columns=["user_id", "movie_id", "rating"])
# print(features.head(50))
label = data['rating']

# 划分训练集和测试集(8,2开)
split_index = int(len(data) * 0.8)
train_x, train_y = features[:split_index], label[:split_index]
test_x, test_y = features[split_index:], label[split_index:]

# 训练GBDT模型
n_estimators=100
gbdt_model = lgb.LGBMRegressor(n_estimators=n_estimators, max_depth=5, learning_rate=0.1)
gbdt_model.fit(train_x, train_y)
gbdt_train_leaves = gbdt_model.predict(train_x, pred_leaf=True)
gbdt_test_leaves = gbdt_model.predict(test_x, pred_leaf=True)

# 将GBDT输出的叶子节点ID转换为one-hot编码的特征
one_hot = OneHotEncoder()
one_hot_train = one_hot.fit_transform(gbdt_train_leaves).toarray()
one_hot_test = one_hot.fit_transform(gbdt_test_leaves).toarray()

# 训练LR模型
lr_model = LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
                 intercept_scaling=1, l1_ratio=None, max_iter=100,
                 multi_class='auto', n_jobs=None, penalty='l2',
                 random_state=None, solver='lbfgs', tol=0.0001, verbose=0,
                 warm_start=False)
lr_model.fit(one_hot_train, train_y)

# 在测试集上评估模型性能
y_pred = lr_model.predict(one_hot_test)
print(f"Accuracy: {accuracy_score(test_y, y_pred)}")
print(f"Precision: {precision_score(test_y, y_pred, average='macro')}")
print(f"Recall: {recall_score(test_y, y_pred, average='macro')}")
print(f"F1-Score (macro): {f1_score(test_y, y_pred, average='macro')}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1504750.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

指针篇章-(冒泡排序详解)

冒泡排序 图解 tmp图解 内容图解 每次循环的次数减少 for循环详解 冒泡排序是一种简单的排序算法,它重复地遍历要排序的数列, 一次比较两个元素,如果它们的顺序错误就把它们交换过来。 遍历数列的工作是重复地进行直到没有再需要交换&…

读算法的陷阱:超级平台、算法垄断与场景欺骗笔记05_共谋(中)

1. 默许共谋 1.1. 又称寡头价格协调(Oligopolistic Price Coordination)或有意识的平行行为(Conscious Parallelism) 1.1.1. 在条件允许的情况下,它会发生在市场集中度较高的行业当中 1.1.…

运维随录实战(13)之docker搭建mysql集群(pxc)

了解 MySQL 集群之前,先看看单节点数据库的弊病 大型互联网程序用户群体庞大,所以架构需要特殊设计。单节点数据库无法满足大并发时性能上的要求。单节点的数据库没有冗余设计,无法满足高可用。单节点 MySQL无法承载巨大的业务量,数据库负载巨大常见 MySQL 集群方案 Re…

STM32_3-1点亮LED灯与蜂鸣器发声

STM32之GPIO GPIO在输出模式时可以控制端口输出高低电平,用以驱动Led蜂鸣器等外设,以及模拟通信协议输出时序等。 输入模式时可以读取端口的高低电平或电压,用于读取按键输入,外接模块电平信号输入,ADC电压采集灯 GP…

【C语言】linux内核ip_generic_getfrag函数

一、讲解 这个函数ip_generic_getfrag是传输层用于处理分段和校验和的一个辅助函数,它通常用在IP层当需要从用户空间拷贝数据构建成网络数据包时。这个函数的实现提供了拷贝数据和进行校验和计算(如果需要的话)的功能。函数的参数解释如下&a…

JVM知识整体学习

前言:本篇没有任何建设性的想法,只是我很早之前在学JVM时记录的笔记,只是想从个人网站迁移过来。文章其实就是对《深入理解JVM虚拟机》的提炼,纯基础知识,网上一搜一大堆。 一、知识点脑图 本文只谈论HotSpots虚拟机。…

利用华为CodeArts持续交付项目演示流程

软件开发生产线(CodeArts)是面向开发者提供的一站式云端平台,即开即用,随时随地在云端交付软件全生命周期,覆盖需求下发、代码提交、代码检查、代码编译、验证、部署、发布,打通软件交付的完整路径&#xf…

前端实现跨域的六种解决方法

本专栏是汇集了一些HTML常常被遗忘的知识,这里算是温故而知新,往往这些零碎的知识点,在你开发中能起到炸惊效果。我们每个人都没有过目不忘,过久不忘的本事,就让这一点点知识慢慢渗透你的脑海。 本专栏的风格是力求简洁…

高吞吐SFTP连接池设计方案

背景 在现代的数据驱动环境中,安全文件传输协议(SFTP)扮演着至关重要的角色,它提供了一种安全、可靠的文件传输方式。我们目前项目是一个大型数据集成平台,跟上下游有很多文件对接是通过SFTP协议,当需要处…

Speech Processing (LASC11158)

大纲 PHON – phonetics and phonology 1. Phonetics and Representations of Speech2. Acoustics of Consonants and VowelsSIGNALS – signal processing, with a focus on speech signals 3. Digital Speech Signals4. the Source-Filter ModelTTS – text-to-speech synth…

基本计算器II

文章目录 题目解析算法解析算法模拟第一步 第二步第三步第四步第五步第六步最后一步 代码 题目解析 题目链接 我们先来看一下题目这个题目的意思很明确就是给你一个算数式让你计算结果并返回并且给了很多辅助条件来帮助你。 算法解析 那么我们来看看这个题目有哪些做法&…

基于机器学习的工业用电量预测完整代码数据

视频讲解: 毕业设计:算法+系统基于机器学习的工业用电量预测完整代码数据_哔哩哔哩_bilibili 界面展示: 结果分析与展示: 代码: from sklearn import preprocessing import random from sklearn.model_selection import train_test_split from sklearn.preprocessing…

【LGR-176-Div.2】[yLCPC2024] 洛谷 3 月月赛 I(A~C and G<oeis>)

[yLCPC2024] A. dx 分计算 前缀和提前处理一下区间和&#xff0c;做到O&#xff08;1&#xff09;访问就可以过。 #include <bits/stdc.h> //#define int long long #define per(i,j,k) for(int (i)(j);(i)<(k);(i)) #define rep(i,j,k) for(int (i)(j);(i)>(k);…

【数学】【组合数学】1830. 使字符串有序的最少操作次数

作者推荐 视频算法专题 本博文涉及知识点 数学 组合数学 LeetCode1830. 使字符串有序的最少操作次数 给你一个字符串 s &#xff08;下标从 0 开始&#xff09;。你需要对 s 执行以下操作直到它变为一个有序字符串&#xff1a; 找到 最大下标 i &#xff0c;使得 1 < i…

深入理解 Webpack 热更新原理:提升开发效率的关键

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

基于php的用户登录实现(v2版)(持续迭代)

目录 版本说明 数据库连接 登录页面&#xff1a;login.html 登录处理实现&#xff1a;login.php 用户欢迎页面&#xff1a;welcome.php 密码修改页面&#xff1a;change_password.html 修改执行&#xff1a;change_password.php 用户注册页面&#xff1a;register.html …

报错Importing ArkTS files to JS and TS files is not allowed. <etsLint>

ts文件并不支持导入ets文件&#xff0c;为了方便开发应用卡片&#xff0c;entryformAbility创建的时候默认是ts文件&#xff0c;这里只需要把ts文件改成ets便可以轻松的导入所需要的ets即可 我创建了一个鸿蒙开发的交流群&#xff0c;喜欢的鸿蒙朋友可以扫码或者写群号&#xf…

指针详解(从基础到入门)

一、什么是指针 在计算机科学中&#xff0c;指针是编程语言中的一个对象&#xff0c;利用地址&#xff0c;它直接指向存在电脑储存器中另一个地方的值。由于通过地址能找到所需的变量单元&#xff0c;可以说&#xff0c;指针指向该变量单元。因此&#xff0c;将地址形象化地称…

如何使用Hexo搭建个人博客

文章目录 如何使用Hexo搭建个人博客环境搭建连接 Github创建 Github Pages 仓库本地安装 Hexo 博客程序安装 HexoHexo 初始化和本地预览 部署 Hexo 到 GitHub Pages开始使用发布文章网站设置更换主题常用命令 插件安装解决成功上传github但是web不更新不想上传文章处理方式链接…

python处理geojson为本地shp文件

一.成果展示 二.环境 我是在Anaconda下的jupyter notebook完成代码的编写&#xff0c;下面是我对应的版本号&#xff0c;我建议大家在这个环境下编写&#xff0c;因为在下载gdal等包的时候会更方便。 二.参考网站 osgeo.osr module — GDAL documentation osgeo.ogr module …