疯狂交互学习的BM3推荐算法(论文复现)
本文所涉及所有资源均在传知代码平台可获取
文章目录
- 疯狂交互学习的BM3推荐算法(论文复现)
- 多模态推荐系统
- 优点
- 示例
- 对比学习
- 什么是对比学习?
- 关键思想
- 优点
- 自监督学习
- 什么是自监督学习?
- 优点
- 实现自监督学习的方法
- 解决方案
- 框架图
- 损失函数(每一步都是自监督对比)
- 图重构损失
- 模态间对齐损失
- 模态内特征遮蔽损失
- 实验分析
- 环境部署
- 数据集配置
- 代码运行
- 运行截图
- 代码分析
- Loss分析
- 参数代码分析
- 相关文件作用分析
- 代码算法复现结果
- 创新思路方向
- 总结
多模态推荐系统
什么是多模态推荐系统?
多模态推荐系统是一种利用多种不同类型的数据源(例如文本、图像、视频、音频等)来进行推荐的系统。传统的推荐系统通常只依赖于单一模态的数据,例如用户的评分或点击行为,而多模态推荐系统则结合了来自多个模态的信息,从而可以提供更准确和个性化的推荐。
优点
- 提高推荐准确性:通过结合多种数据源,可以更全面地了解用户的偏好。
- 丰富的用户体验:多模态数据可以为用户提供更多样化的推荐内容。
- 处理冷启动问题:在用户数据不足的情况下,可以利用其他模态的数据进行推荐。
示例
假设我们有一个电商平台,用户在平台上浏览和购买商品。我们可以使用以下多模态数据来构建推荐系统:
- 文本:商品的描述和用户的评论。
- 图像:商品的图片。
- 行为:用户的点击和购买记录。
对比学习
什么是对比学习?
对比学习是一种自监督学习的方法,通过学习样本之间的相似性和差异性来学习数据的有用表示。目标是使得相似的样本在表示空间中更接近,不相似的样本更远离。
关键思想
- 正样本对 (Positive Pairs):具有相似特征的样本对。
- 负样本对 (Negative Pairs):具有不同特征的样本对。
- 损失函数:通过最小化正样本对之间的距离,最大化负样本对之间的距离来训练模型。
优点
- 无需大量标注数据:对比学习可以在无监督环境中工作。
- 提升特征表达能力:通过对比学习,模型可以学习到更有辨别力的特征。
自监督学习
什么是自监督学习?
自监督学习是一种无监督学习的方法,通过生成伪标签来进行训练。模型利用自身的数据生成训练信号,而不是依赖外部的标签数据。自监督学习的目标是通过设计预训练任务,使模型能够学习到数据的有用表示。
优点
- 减少对标签数据的依赖:不需要大量的人工标注数据。
- 学习到通用特征:通过预训练任务,模型可以学习到适用于多个下游任务的通用特征。
实现自监督学习的方法
常见的自监督学习方法包括:
- 图像领域:通过图像旋转、遮挡、拼图等任务来生成伪标签。
- 文本领域:通过词汇预测、句子排序等任务来生成伪标签。
论文问题提出
- 除了用户-项目交互图之外,现有的最先进的方法通常使用辅助图(例如,用户-用户或项目-项目关系图),以增强所学习的用户和/或项目的表示。这些表示通常使用图卷积网络在辅助图上传播和聚合,这在计算和存储器方面可能非常昂贵,特别是对于大型图。
- 现有的多模态推荐方法通常利用贝叶斯个性化排名(BPR)损失中随机抽样的否定示例来指导用户/项目表示的学习,这增加了大型图上的计算成本,并且还可能将噪声监督信号带入训练过程。
解决方案
- 自监督学习的应用:
BM3 提出了一个新的自监督学习模型,不需要使用负样本或复杂的图增强技术。这简化了现有的自监督学习框架,减少了模型参数。
- Dropout 增强机制:
通过 dropout 增强生成用户和项目的对比视图,而不是通过图或图像增强。这种设计减少了内存和计算成本。
- 多模态对比损失函数:
设计了一个专门用于多模态推荐的对比损失函数,该函数在重建用户-项目交互图的同时对齐不同模态之间的特征,并减少来自同一模态的不同增强视图之间的差异。
框架图
该框架图展示了BM3模型的结构,包括几个关键部分。首先是"Backbone Network"(骨干网络),它接收用户和物品的ID嵌入,并生成初始嵌入表示 huh**u 和 hih**i。然后,这些嵌入与物品的视觉特征和文本特征通过投影网络 fvf**v 和 ftf**t 进行处理,生成图像和文本的嵌入表示 hvh**v 和 hth**t。接下来,“Contrastive View Generator”(对比视图生成器)通过增强技术生成这些嵌入的对比视图(例如 hu*h*u、 hi*h*i、 hv*h*v、 ht*h*t),并应用于三个损失函数。“Graph Reconstruction Loss” LrecLrec 通过对比用户和物品嵌入及其对比视图来增强嵌入表示的鲁棒性和泛化能力;“Inter-modality Feature Alignment Loss” LalignLalign 通过对比不同模态(例如图像和文本)的嵌入和对比视图,促进跨模态的一致性;“Intra-modality Feature Masked Loss” LmaskLmask 通过对比同一模态内部的嵌入和对比视图,进一步增强单模态的鲁棒性。最终,这些损失函数的加权和形成了整体的多模态对比损失 LL,优化模型以提升推荐系统的性能。
损失函数(每一步都是自监督对比)
图重构损失
Lrec=−(cos(h**u,h**i′)+cos(h**i,h**u′))
-huh**u是用户的嵌入表示。
-hih**i是项目的嵌入表示。
-hu′h**u′是用户的对比视图嵌入表示。
-hi′h**i′是项目的对比视图嵌入表示。
- 假设用户uu对项目ii有正反馈,那么huh**u和hih**i应该有较高的相似度。
- 通过对比学习,如果huh**u与hi′h**i′(项目的对比视图)也有较高的相似度,这表明模型对项目特征的变化(如视图变化、噪声)具有鲁棒性。
- 新用户u′u′可能没有与很多项目交互过,但如果hu′h**u′(用户的对比视图)与某些项目的嵌入hih**i保持相似性,那么模型可以根据hu′h**u′推荐相关的项目ii。
- 类似地,新项目i′i′可能没有很多用户交互数据,但如果hi′h**i′(项目的对比视图)与某些用户的嵌入huh**u保持相似性,那么模型可以根据hi′h**i′推荐给相关的用户uu。
- 如果模型只学习huh**u和hih**i的相似性,可能会导致模型只记住某些用户-项目对,而无法泛化到其他用户-项目对。
- 通过对称的损失,即huh**u和hi′h**i′以及hih**i和hu′h**u′的相似性,模型必须学习更广泛的特征,从而减少模式崩溃的风险。
- 增强鲁棒性:通过对比学习,模型需要在不同的增强视图之间保持一致性,这使得模型对噪声和变动具有更强的鲁棒性。用户的嵌入和项目的对比视图之间的一致性可以防止模型过拟合到特定的用户-项目对。
- 促进泛化能力:通过对比用户和项目的对比视图嵌入,模型能够学习到更通用的特征表示。这使得模型在面对新的数据或未见过的用户-项目对时,仍然能够保持较好的性能。
- 减少模式崩溃:在对比学习中,如果只关注正例对的相似性,可能会导致模式崩溃(模式崩溃指的是模型只记住了特定的模式而未能学习到通用的特征)。通过对称的对比视图嵌入损失,可以有效防止模式崩溃。
模态间对齐损失
相当于Item是标签,这些Text和Image是特征,相互学习的过程,把Text赋予标签信息,然后在Item里面增加更多的Text和Image的特征信息,同时由于Dropout可以保证学习的不崩溃
-
统一性和稳定性
- 项目(item)的嵌入表示相对于用户(user)的嵌入表示更为稳定和统一。用户的行为和兴趣可能会随时间和情境发生变化,而项目的特征相对固定,因此使用项目嵌入可以提供更稳定的对齐基础。
-
多视图一致性
- 多视图特征表示 $ h_m’ $ 是从不同模态(如文本、图像、音频等)中提取的。这些特征通常描述的是项目的不同方面,因此使用项目的嵌入来对齐多视图特征可以确保不同模态下的项目特征一致性。
-
提高泛化能力
- 使用项目嵌入来对齐多视图特征可以帮助模型更好地捕捉项目的多模态特性,从而提高模型在处理多模态推荐任务时的泛化能力。这意味着模型可以更好地理解和推荐多种类型的项目,即使在用户行为发生变化时,模型仍然能够提供有效的推荐。
模态内特征遮蔽损失
生成对比视图:通过dropout生成图像和文本的对比视图嵌入 hv′h**v′ 和 ht′h**t′。
计算余弦相似度:计算图像嵌入 hvh**v 和其对比视图嵌入 hv′h**v′ 之间的余弦相似度,以及文本嵌入 hth**t 和其对比视图嵌入 ht′h**t′ 之间的余弦相似度:
计算单模态特征屏蔽损失:将上述两个余弦相似度的负值求和,得到最终的单模态特征屏蔽损失:
- 假设图像 v 经过数据增强(如旋转、裁剪等)后生成对比视图 v’。通过使hv和 hv’ 具有高相似度,模型可以更好地应对图像中的噪声和变动,保证图像嵌入的一致性。
- 如果模型能够在不同的图像视图(如不同的拍摄角度或光照条件)之间保持一致性,那么当遇到新的图像时(如不同场景或对象),模型也能够有效地提取相关特征。
- 如果模型只学习原始图像的特征,可能会过拟合到特定的图像内容或风格。而通过对比原始图像和其增强视图,模型必须学习更通用的图像特征,从而减少模式崩溃的风险。
- 增强单模态的鲁棒性
通过对比学习,模型需要在同一模态内的不同视图之间保持一致性,这使得模型在面对该模态的数据变动时具有更强的鲁棒性。
- 促进模态内的泛化能力
通过对比单模态内的嵌入表示和对比视图,模型能够学习到更通用的特征表示。这使得模型在面对同一模态的新的数据时,仍然能够保持较好的性能。
- 减少单模态的模式崩溃
在单模态的对比学习中,如果只关注单一视图的特征,可能会导致模式崩溃(即模型只记住了特定的模式而未能学习到通用的特征)。通过对比视图嵌入损失,可以有效防止模式崩溃。
实验分析
环境部署
git clone https://github.com/enoche/BM3.git
环境配置
pip install -r requirements.txt
conda install --file requirements.txt
数据集配置
通过这个地址–>dataset下载>
baby
`elec\
sports`这三个数据集,然后将这些文件放入源码的data文件夹下。
代码运行
进入到
src
目录下
cd .\src
然后执行命令,-m 代表模型的名称 -d 代表数据集名称
python main.py -m BM3 -d baby
运行截图
数据集的相关统计数据
模型结构
训练的准确率结果
代码分析
提示:除Loss以及mian文件的分析外,在视频讲解中将会对模型进行一些简单的分析,以便帮助初学者理解模型的搭建,帮助读者进行自我创新。
Loss分析
这部分通过在不同模态不同视图之间进行Loss,可以实现论文中的Loss创新,并且框架中Dropout也在该部分,具体分析均在以下代码的注释当中
def calculate_loss(self, interactions):
# online network
u_online_ori, i_online_ori = self.forward()
t_feat_online, v_feat_online = None, None
if self.t_feat is not None:
t_feat_online = self.text_trs(self.text_embedding.weight)
if self.v_feat is not None:
v_feat_online = self.image_trs(self.image_embedding.weight)
with torch.no_grad(): # 停止梯度更新,这样在下面的操作中不会计算梯度,节省内存和计算资源
u_target, i_target = u_online_ori.clone(), i_online_ori.clone() # 复制在线用户和物品的原始特征向量
u_target.detach() # 分离用户目标特征向量,使其不参与梯度计算
i_target.detach() # 分离物品目标特征向量,使其不参与梯度计算
u_target = F.dropout(u_target, self.dropout) # 对用户目标特征向量应用Dropout,生成用户对比试图
i_target = F.dropout(i_target, self.dropout) # 对物品目标特征向量应用Dropout,生成物品对比试图
if self.t_feat is not None: # 检查时间特征是否存在
t_feat_target = t_feat_online.clone() # 复制时间特征向量
t_feat_target = F.dropout(t_feat_target, self.dropout) # 对时间特征向量应用Dropout,生成image对比试图
if self.v_feat is not None: # 检查image特征是否存在
v_feat_target = v_feat_online.clone() # 复制image特征
v_feat_target = F.dropout(v_feat_target, self.dropout) # 对image特征向量Dropout,生成text对比试图
# 预测用户和物品的在线特征向量
u_online, i_online = self.predictor(u_online_ori), self.predictor(i_online_ori)
# 获取交互数据中的用户和物品索引
users, items = interactions[0], interactions[1]
# 根据用户和物品索引提取相应的在线特征和目标特征
u_online = u_online[users, :] # 提取在线用户特征
i_online = i_online[items, :] # 提取在线物品特征
u_target = u_target[users, :] # 提取目标用户特征
i_target = i_target[items, :] # 提取目标物品特征
# 初始化各类损失为0
loss_t, loss_v, loss_tv, loss_vt = 0.0, 0.0, 0.0, 0.0
if self.t_feat is not None: # 检查时间特征是否存在
t_feat_online = self.predictor(t_feat_online) # 通过预测器更新在线时间特征
t_feat_online = t_feat_online[items, :] # 提取更新后的在线时间特征
t_feat_target = t_feat_target[items, :] # 提取目标时间特征
# 计算时间特征和物品目标特征的余弦相似度损失
loss_t = 1 - cosine_similarity(t_feat_online, i_target.detach(), dim=-1).mean()
# 计算时间特征和目标时间特征的余弦相似度损失
loss_tv = 1 - cosine_similarity(t_feat_online, t_feat_target.detach(), dim=-1).mean()
if self.v_feat is not None: # 检查视觉特征是否存在
v_feat_online = self.predictor(v_feat_online) # 通过预测器更新在线视觉特征
v_feat_online = v_feat_online[items, :] # 提取更新后的在线视觉特征
v_feat_target = v_feat_target[items, :] # 提取目标视觉特征
# 计算视觉特征和物品目标特征的余弦相似度损失
loss_v = 1 - cosine_similarity(v_feat_online, i_target.detach(), dim=-1).mean()
# 计算视觉特征和目标视觉特征的余弦相似度损失
loss_vt = 1 - cosine_similarity(v_feat_online, v_feat_target.detach(), dim=-1).mean()
# 计算用户在线特征和物品目标特征的余弦相似度损失
loss_ui = 1 - cosine_similarity(u_online, i_target.detach(), dim=-1).mean()
# 计算物品在线特征和用户目标特征的余弦相似度损失
loss_iu = 1 - cosine_similarity(i_online, u_target.detach(), dim=-1).mean()
# 返回总损失,包括余弦相似度损失、正则化损失和对比损失
return (loss_ui + loss_iu).mean() + self.reg_weight * self.reg_loss(u_online_ori, i_online_ori) + \
self.cl_weight * (loss_t + loss_v + loss_tv + loss_vt).mean()
参数代码分析
具体分析请看注释
# 创建ArgumentParser对象用于解析命令行参数
parser = argparse.ArgumentParser()
# 添加命令行参数 --model 或 -m,用于指定模型名称,默认值为 'BM3'
parser.add_argument('--model', '-m', type=str, default='BM3', help='name of models')
# 添加命令行参数 --dataset 或 -d,用于指定数据集名称,默认值为 'baby'
parser.add_argument('--dataset', '-d', type=str, default='baby', help='name of datasets')
# 定义包含GPU配置信息的字典
config_dict = {
'gpu_id': 0,
}
# 解析命令行参数,将结果存储在 args 对象中
args, _ = parser.parse_known_args()
# 调用 quick_start 函数,传递模型名称、数据集名称、配置字典以及是否保存模型的标志
quick_start(model=args.model, dataset=args.dataset, config_dict=config_dict, save_model=True)
相关文件作用分析
BM3/
├── data/ # 数据目录
│ ├── baby/ # 婴儿数据目录
│ ├── clothing/ # 服装数据目录
│ └── sports/ # 运动数据目录
│
├── src/ # 源代码目录
│ ├── common/ # 公共模块目录
│ ├── configs/ # 配置目录
│ ├── log/ # 日志目录
│ ├── models/ # 模型目录
│ │ └── bm3.py # BM3模型代码
│ └── utils/ # 工具目录
│ └── main.py # 主程序代码
│
├── trained-models-logs/ # 训练模型日志目录
│
├── .gitignore # git忽略文件
├── LICENSE # 许可证文件
├── README.md # 项目说明文件
└── requirements.txt # 项目依赖文件
代码算法复现结果
提示:由于领域特殊性,其他预测数据无法展示,只能对模型预测的相关准确率指标数据进行分享,请见谅,同时由于不同服务器以及相应python包版本存在差异,会存在复现结果不尽相同,这是正常现象。该结果在RTX 4090得到。
Category | N@10 | N@20 | R@10 | R@20 |
---|---|---|---|---|
Baby | 0.0559 | 0.0880 | 0.0296 | 0.0383 |
Sports | 0.0646 | 0.0978 | 0.0345 | 0.0435 |
Electronics | 0.0434 | 0.0643 | 0.0247 | 0.0301 |
创新思路方向
通过对现有数据集进行统计我们可以发现,现在大部分数据集都存在交互稀疏性的问题
这样的稀疏性将会严重影响模型的性能以及准确率,这是一个可以着重解决的问题
目前在用户与物品的交互信息存在大量的噪声,许多交互并不能表明用户对于该物品存在兴趣,并且由时序问题,对于在不同的时间段,同一个用户的兴趣点可能发生变化,在前的大量的交互可能对改变后的兴趣产生影响,可以引入Tempary进行学习
总结
BM3模型是一种创新的多模态推荐系统,它通过自监督学习框架,利用dropout增强机制和多模态对比损失函数,在不使用负样本或复杂图增强技术的情况下,有效地对齐不同模态的特征并减少内部视图的差异。该模型通过图重构损失、模态间对齐损失和模态内特征遮蔽损失,增强了模型的鲁棒性、泛化能力,并减少了模式崩溃的风险。实验结果表明,BM3在多个数据集上都取得了优异的性能。
通过对现有数据集的统计分析,我们发现交互稀疏性和噪声是推荐系统领域的主要挑战。读者可以针对这部分问题进行相应的解决创新
文章代码资源点击附件获取