【深度学习】Pytorch项目实战-基于协同过滤实现物品推荐系统

news2025/2/22 13:48:13

一、推荐系统的了解

1. 定义

推荐系统是一个信息过滤系统,旨在为用户提供个性化的内容推荐。它利用用户的历史行为、偏好以及其他相关数据来推测用户可能感兴趣的项目或信息。推荐系统广泛应用于电子商务、社交媒体、流媒体服务等领域,帮助用户发现商品、电影、音乐、文章等。

2. 推荐系统的基本类型

推荐系统主要可以分为以下几类:

2.1 基于内容的推荐(Content-based Filtering)
  • 基于用户过去喜欢的物品的特征,推荐具有相似特征的新物品。
  • 例如,如果用户在看电影时对科幻类电影表现出偏好,系统会推荐其他科幻电影。
  • 使用特征提取技术(如词袋模型、TF-IDF等)来分析物品内容。
2.2 协同过滤推荐(Collaborative Filtering)

基于用户与其他用户之间的互动和偏好,推荐相似用户喜欢的物品。
(1)有两种主要的协同过滤方法:

  • 用户协同过滤:寻找与目标用户相似的用户,推荐这些用户喜爱的物品。
  • 物品协同过滤:寻找与目标物品相似的物品,推荐用户已喜欢的物品。

优点是无需了解物品的具体内容,只需关注用户行为。

2.3 混合推荐(Hybrid Methods):
  • 结合多种推荐方法来产生更准确和强大的推荐结果。
  • 例如,可以结合内容过滤和协同过滤,以弥补各自的不足。

3. 推荐系统在实际应用中面临一些挑战

(1)冷启动问题:新用户或新项目没有足够的数据来生成推荐。
(2)用户隐私:如何在不泄漏用户隐私的情况下收集和使用数据。
(3)多样性与新颖性:避免过于集中于用户过去的偏好,提供更多样化和新颖的推荐。
(4)数据稀疏性:特别是在大规模用户和物品的情况下,数据稀疏会影响推荐质量。

二、推荐系统项目实战

在本例中,使用矩阵分解(Matrix Factorization) 方法来实现协同过滤。这种方法通过将用户-物品交互矩阵分解为两个低维矩阵(用户嵌入和物品嵌入),从而预测用户对未评分物品的偏好。以下是实现步骤:

1. 数据准备

我们需要一个用户-物品交互数据集。例如:
用户 ID
物品 ID
评分(或点击次数)
示例数据:

import pandas as pd
# 创建模拟数据
data = {
    "user_id": [0, 0, 1, 1, 2, 2, 3, 3],
    "item_id": [0, 1, 0, 2, 1, 2, 0, 1],
    "rating": [5, 3, 4, 2, 5, 1, 3, 4]
}
df = pd.DataFrame(data)
print(df)

输出:

   user_id  item_id  rating
0        0        0       5
1        0        1       3
2        1        0       4
3        1        2       2
4        2        1       5
5        2        2       1
6        3        0       3
7        3        1       4

2. 数据预处理

我们需要将用户 ID 和物品 ID 转换为连续的索引,并创建训练数据集。
数据预处理代码:

from torch.utils.data import Dataset, DataLoader
class RatingDataset(Dataset):
    def __init__(self, df):
        self.users = df["user_id"].values
        self.items = df["item_id"].values
        self.ratings = df["rating"].values
    def __len__(self):
        return len(self.users)

    def __getitem__(self, idx):
        return {
            "user_id": self.users[idx],
            "item_id": self.items[idx],
            "rating": self.ratings[idx]
        }
# 创建数据集和数据加载器
dataset = RatingDataset(df)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

3. 模型设计

我们使用矩阵分解方法,将用户和物品映射到低维嵌入空间,并通过点积计算预测评分。
模型代码:

import torch
import torch.nn as nn
class MatrixFactorization(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim):
        super(MatrixFactorization, self).__init__()
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)
    def forward(self, user_ids, item_ids):
        user_embeds = self.user_embedding(user_ids)
        item_embeds = self.item_embedding(item_ids)
        ratings = (user_embeds * item_embeds).sum(dim=1)
        return ratings

4. 训练模型

定义损失函数和优化器,并训练模型。
训练代码:

# 初始化模型、损失函数和优化器
num_users = df["user_id"].nunique()
num_items = df["item_id"].nunique()
embedding_dim = 8
model = MatrixFactorization(num_users, num_items, embedding_dim)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# 训练循环
num_epochs = 10
for epoch in range(num_epochs):
    total_loss = 0
    for batch in dataloader:
        user_ids = batch["user_id"]
        item_ids = batch["item_id"]
        ratings = batch["rating"]
        # 前向传播
        predicted_ratings = model(user_ids, item_ids)
        loss = criterion(predicted_ratings, ratings.float())
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

5. 模型评估

我们可以使用测试数据集评估模型性能,或者直接预测用户对未评分物品的偏好。
预测代码:

# 测试预测
test_user_id = torch.tensor([0])  # 用户 ID
test_item_id = torch.tensor([2])  # 物品 ID
predicted_rating = model(test_user_id, test_item_id)
print(f"Predicted rating for user {test_user_id.item()} and item {test_item_id.item()}: {predicted_rating.item():.4f}")

6. 推荐物品

根据预测评分,为用户推荐评分最高的物品。
推荐代码:

def recommend_items(model, user_id, num_items, top_k=3):
    item_ids = torch.arange(num_items)
    user_ids = torch.full_like(item_ids, user_id)
    predicted_ratings = model(user_ids, item_ids)
    # 获取评分最高的物品
    top_items = torch.topk(predicted_ratings, top_k).indices
    return top_items.tolist()
# 为用户 0 推荐物品
recommended_items = recommend_items(model, user_id=0, num_items=num_items, top_k=3)
print(f"Recommended items for user 0: {recommended_items}")

三、总结

3.1 实现推荐系统核心步骤

  • 数据准备:收集用户-物品交互数据。
  • 数据预处理:将数据转换为 PyTorch 数据集。
  • 模型设计:使用矩阵分解方法构建推荐模型。
  • 模型训练:定义损失函数和优化器,训练模型。
  • 模型评估:测试模型性能,预测用户对物品的评分。
  • 推荐物品:根据预测评分生成推荐列表。

通过上述步骤,你可以快速实现一个基于 PyTorch 的推荐系统,并根据需求进一步扩展功能。

3.2 扩展方向

  • 多模态推荐:结合文本、图像等信息提升推荐效果。
  • 深度学习模型:使用神经协同过滤(NeuMF)或 Transformer 模型。
  • 在线学习:支持实时更新用户行为数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2303440.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

空字符串““、空白字符串“ “和 null 三者的区别

空字符串、空白字符串和 null 三者的区别表格: 类型定义示例长度是否有值空字符串字符串长度为 0,但不是 null,即存在一个有效的空字符串对象。""0有值(空值)空白字符串字符串包含空格、制表符等空白字符&a…

在mfc中使用自定义三维向量类和计算多个三维向量的平均值

先添加一个普通类, Vector3.h, // Vector3.h: interface for the Vector3 class. // //#if !defined(AFX_VECTOR3_H__53D34D26_95FF_4377_BD54_57F4271918A4__INCLUDED_) #define AFX_VECTOR3_H__53D34D26_95FF_4377_BD54_57F4271918A4__INCLUDED_#if _MSC_VER > 1000 #p…

多线程和并发篇

多线程和并发篇 创建一个对象时底层汇编指令实现步骤(cpu可能会进行指令重排序):一、二、三级缓存的实现:并发编程三要素:线程的五大状态:创建线程的三种方式:线程的特征和状态:Thre…

【3.5JavaScript】JavaScript字符串对象

文章目录 1.获取字符串长度2.大小写转换3.获取某一个字符4.截取字符串5.替换字符串6.分割字符串7.检索字符串位置8.例题:统计某一个字符的个数 在 JavaScript 中,对象是非常重要的知识点。对象分为两种:一种是 ”自定义对象“,另…

路由基本配置

学习目标 • 根据拓扑图进行网络布线。 • 清除启动配置并将路由器重新加载为默认状态。 • 在路由器上执行基本配置任务。 • 配置并激活以太网接口。 • 测试并检验配置。 • 思考网络实施方案并整理成文档。 任务 1:网络布线 使用适当的电缆类型连接网络设备。…

windows上vscode cmake工程搭建

安装vscode插件: 1.按装fastc(主要是安装MinGW\mingw64比较方便) 2.安装C,cmake,cmake tools插件 3.准备工作完成之后,按F1,选择cmake:Quick Start就可以创建一个cmake工程。 4.设置Cmake: G…

VUE3+TS+element-plus项目从0开始入门 - 创建项目、认识基本结构

文章目录 写在前面1、创建vue3项目npm create vuelatestnpm i 2、项目结构.vscodevue3结构a、项目树结构b、package.jsonc、tsconfig.jsond、index.htmld、srce、main.tsf、App.vue 写在前面 开前请自行下载vs code、node.js, 在vs code里面安装Vue - Official插件。本文使用的…

shared_ptr 不析构的问题记录

片段1: 片段2: 你们猜 哪个有问题 ?

原生稀疏注意力机制(NSA):硬件对齐且可原生训练的稀疏注意力机制-论文阅读

摘要 长上下文建模对于下一代语言模型至关重要,但标准注意力机制的高计算成本带来了巨大的计算挑战。稀疏注意力提供了一种在保持模型能力的同时提高效率的有前途的方向。本文提出了一种名为 NSA(原生可训练稀疏注意力机制) 的方法&#xff…

从0到1:固件分析

固件分析 0x01 固件提取 1、从厂商官网下载 例如D-link的固件: https://support.dlink.com/resource/products/ 2、代理或镜像设备更新时的流量 发起中间人攻击MITM #启用IP转发功能 echo 1 > /proc/sys/net/ipv4/ip_forward#配置iptables,将目…

conda、anaconda、pip、pytorch、tensorflow有什么区别?

先画一张图,可以大致看出它们的区别和关联: pytorch、tensorflow都是Python的第三方库,相当于封装的代码工具集库,通过import导入使用。这两个都是深度学习框架,用来搭建AI模型什么的,使用范围非常之广&…

项目设置内网 IP 访问实现方案

在我们平常的开发工作中,项目开发、测试完成后进行部署上线。比如电商网站、新闻网站、社交网站等,通常对访问不会进行限制。但是像企业内部网站、内部管理系统等,这种系统一般都需要限制访问,比如内网才能访问等。那么一个网站应…

Vue面试2

1.跨域问题以及如何解决跨域 跨域问题(Cross-Origin Resource Sharing, CORS)是指在浏览器中,当一个资源试图从一个不同的源请求另一个资源时所遇到的限制。这种限制是浏览器为了保护用户安全而实施的一种同源策略(Same-origin p…

合合信息2025届春季校园招聘全面启动!

世界因你而AI,合合信息2025届春季校园招聘启动! 我们是谁? 我们是一家行业领先的人工智能及大数据科技企业 18年深耕AI领域,C端产品与B端服务布局矩阵完善 9.4亿全球累计用户首次下载量💥 来到这里你能得到什么&a…

shiro代码层面追踪

文章目录 环境漏洞分析硬编码 反序列化Gadget构造 环境 环境搭建:https://blog.csdn.net/qq_44769520/article/details/123476443 漏洞分析 硬编码 shiro是对rememberMe这个cookie进⾏反序列化的时候出现了问题。 相应代码 // // Source code recreated from …

虚拟机网络ssh连接失败,没有网络

vscode进行ssh时连接失败,发现是虚拟机没有网络。 虚拟机ping不通www.baidu.com但可以ping通内网 ping 8.8.8.8ping不通。 sudo dhclient -r ens33 sudo dhclient ens33 ip route show可以了。 20250221记录:不知道是不是重启了虚拟机还是咋了&#…

已知点矩阵的三个顶点坐标、行列数和行列的间距,计算得出剩余所有点的坐标

已知点矩阵的三个顶点坐标、行列数和行列的间距,计算得出剩余所有点的坐标 计算矩阵中每个点的坐标代码实现案例图调用验证 计算矩阵中每个点的坐标 给定左上角、左下角和右上角三个点的坐标,以及矩阵的行数、列数、行间距和列间距,我们可以…

go 并发 gorouting chan channel select Mutex sync.One

goroutine // head&#xff1a; 前缀 index&#xff1a;是一个int的指针 func print(head string, index *int) {for i : 0; i < 5; i {// 指针对应的int *indexfmt.Println(*index, head, i)// 暂停1stime.Sleep(1 * time.Second)} }/* Go 允许使用 go 语句开启一个新的运…

深度学习入门--python入门2

以前学的全忘了&#xff0c;现在算是才开始学&#xff0c;有错误&#xff0c;恳请指正。 目录 1.4 Python脚本文件 1.4.1保存为文件 1.4.2 类 1.5 Numpy 1.5.1 导入Numpy 1.5.2 生成Numpy数组 1.5.3 Numpy的算术运算 1.5.4 Numpy的N维数组 1.5.5 广播 1.5.6 访问元素…

题海拾贝:【枚举】P2010 [NOIP 2016 普及组] 回文日期

Hello大家好&#xff01;很高兴我们又见面啦&#xff01;给生活添点passion&#xff0c;开始今天的编程之路&#xff01; 我的博客&#xff1a;<但凡. 我的专栏&#xff1a;《编程之路》、《数据结构与算法之美》、《题海拾贝》 欢迎点赞&#xff0c;关注&#xff01; 1、题…