余弦相似度算法进行客户流失分类预测

news2024/10/2 10:40:23

余弦相似性是一种用于计算两个向量之间相似度的方法,常被用于文本分类和信息检索领域。具体来说,假设有两个向量A和B,它们的余弦相似度可以通过以下公式计算:

其中,dot_product(A, B)表示向量A和B的点积,norm(A)和norm(B)分别表示向量A和B的范数。如果A和B越相似,它们的余弦相似度就越接近1,反之亦然。

数据集

我们这里用的演示数据集来自一个datacamp:

这个数据集来自一家伊朗电信公司,每一行代表一个客户一年的时间。除了客户流失标签,还有客户活动的信息,比如呼叫失败和订阅时长等等。我们最后要预测的是这个客户是否流失,也就是一个二元分类的问题。

数据集如下:

 import pandas as pd
 df = pd.read_csv("data/customer_churn.csv")
 

我们先区分训练和验证集:

 from sklearn.model_selection import train_test_split
 
 # split the dataframe into 70% training and 30% testing sets
 train_df, test_df = train_test_split(df, test_size=0.3)

SVM

为了进行对比,我们先使用SVM做一个基础模型

 fromsklearn.svmimportSVC
 fromsklearn.metricsimportclassification_report, confusion_matrix
 
 # define the range of C and gamma values to try
 c_values= [0.1, 1, 10, 100]
 gamma_values= [0.1, 1, 10, 100]
 
 # initialize variables to store the best result
 best_accuracy=0
 best_c=None
 best_gamma=None
 best_predictions=None
 
 # loop over the different combinations of C and gamma values
 forcinc_values:
     forgammaingamma_values:
         # initialize the SVM classifier with RBF kernel, C, and gamma
         clf=SVC(kernel='rbf', C=c, gamma=gamma, random_state=42)
 
         # fit the classifier on the training set
         clf.fit(train_df.drop('Churn', axis=1), train_df['Churn'])
 
         # predict the target variable of the test set
         y_pred=clf.predict(test_df.drop('Churn', axis=1))
 
         # calculate accuracy and store the result if it's the best so far
         accuracy=clf.score(test_df.drop('Churn', axis=1), test_df['Churn'])
         ifaccuracy>best_accuracy:
             best_accuracy=accuracy
             best_c=c
             best_gamma=gamma
             best_predictions=y_pred
 
 # print the best result and the confusion matrix
 print(f"Best result: C={best_c}, gamma={best_gamma}, accuracy={best_accuracy:.2f}")
 print("Confusion matrix:")
 print(confusion_matrix(test_df['Churn'], best_predictions))

可以看到支持向量机得到了87%的准确率,并且很好地预测了客户流失。

余弦相似度算法

这段代码使用训练数据集来计算类之间的余弦相似度。

 importpandasaspd
 fromsklearn.metrics.pairwiseimportcosine_similarity
 
 # calculate the cosine similarity matrix between all rows of the dataframe
 cosine_sim=cosine_similarity(train_df.drop('Churn', axis=1))
 
 # create a dataframe from the cosine similarity matrix
 cosine_sim_df=pd.DataFrame(cosine_sim, index=train_df.index, columns=train_df.index)
 
 # create a copy of the train_df dataframe without the churn column
 train_df_no_churn=train_df.drop('Churn', axis=1)
 
 # calculate the mean cosine similarity for class 0 vs. class 0
 class0_cosine_sim_0=cosine_sim_df.loc[train_df[train_df['Churn'] ==0].index, train_df[train_df['Churn'] ==0].index].mean().mean()
 
 # calculate the mean cosine similarity for class 0 vs. class 1
 class0_cosine_sim_1=cosine_sim_df.loc[train_df[train_df['Churn'] ==0].index, train_df[train_df['Churn'] ==1].index].mean().mean()
 
 # calculate the mean cosine similarity for class 1 vs. class 1
 class1_cosine_sim_1=cosine_sim_df.loc[train_df[train_df['Churn'] ==1].index, train_df[train_df['Churn'] ==1].index].mean().mean()
 
 # display the mean cosine similarities for each pair of classes
 print('Mean cosine similarity (class 0 vs. class 0):', class0_cosine_sim_0)
 print('Mean cosine similarity (class 0 vs. class 1):', class0_cosine_sim_1)
 print('Mean cosine similarity (class 1 vs. class 1):', class1_cosine_sim_1)

下面是它们的余弦相似度:

然后我们生成一个DF

 importpandasaspd
 
 # create a dictionary with the mean and standard deviation values for each comparison
 data= {
     'comparison': ['Class 0 vs. Class 0', 'Class 0 vs. Class 1', 'Class 1 vs. Class 1'],
     'similarity_mean': [class0_cosine_sim_0, class0_cosine_sim_1, class1_cosine_sim_1],
 }
 
 # create a Pandas DataFrame from the dictionary
 df=pd.DataFrame(data)
 
 df=df.set_index('comparison').T
 
 
 # print the resulting DataFrame
 print(df)

下面就是把这个算法应用到训练数据集上。我取在训练集上创建一个sample_churn_0,其中包含10个样本以的距离。

 # create a DataFrame containing a random sample of 10 points where Churn is 0
 sample_churn_0=train_df[train_df['Churn'] ==0].sample(n=10)

然后将它交叉连接到test_df。这将使test_df扩充为10倍的行数,因为每个测试记录的右侧有10个示例记录。

 importpandasaspd
 
 # assume test_df and sample_churn_0 are your dataframes
 
 # add a column to both dataframes with a common value to join on
 test_df['join_col'] =1
 sample_churn_0['join_col'] =1
 
 # perform the cross-join using merge()
 result_df=pd.merge(test_df, sample_churn_0, on='join_col')
 
 # drop the join_col column from the result dataframe
 result_df=result_df.drop('join_col', axis=1)

现在我们对交叉连接DF的左侧和右侧进行余弦相似性比较。

 importpandasaspd
 fromsklearn.metrics.pairwiseimportcosine_similarity
 
 # Extract the "_x" and "_y" columns from the result_df DataFrame, excluding the "Churn_x" and "Churn_y" columns
 df_x=result_df[[colforcolinresult_df.columnsifcol.endswith('_x') andnotcol.startswith('Churn_')]]
 df_y=result_df[[colforcolinresult_df.columnsifcol.endswith('_y') andnotcol.startswith('Churn_')]]
 
 # Calculate the cosine similarities between the two sets of vectors on each row
 cosine_sims= []
 foriinrange(len(df_x)):
     cos_sim=cosine_similarity([df_x.iloc[i]], [df_y.iloc[i]])[0][0]
     cosine_sims.append(cos_sim)
 
 # Add the cosine similarity values as a new column in the result_df DataFrame
 result_df['cos_sim'] =cosine_sims

然后用下面的代码提取所有的列名:

 x_col_names = [col for col in result_df.columns if col.endswith('_x')]

这样我们就可以进行分组并获得每个test_df记录的平均余弦相似度(目前重复10次),然后在grouped_df中,我们将其重命名为x_col_names:

 grouped_df = result_df.groupby(result_df.columns[:14].tolist()).agg({'cos_sim': 'mean'})
 
 grouped_df = grouped_df.rename_axis(x_col_names).reset_index()
 
 grouped_df.head()

最后我们计算这10个样本的平均余弦相似度。

在上面步骤中,我们计算的分类相似度的df是这个:

我们就使用这个数值作为分类的参考。首先,我们需要将其交叉连接到grouped_df(与test_df相同,但具有平均余弦相似度):

 cross_df = grouped_df.merge(df, how='cross')
 cross_df = cross_df.iloc[:, :-1]

结果如下:

最后我们得到了3列:Class 0 vs. Class 0, and Class 0 vs. Class 1,然后我们需要得到类之间的差别:

 cross_df['diff_0'] = abs(cross_df['cos_sim'] - df['Class 0 vs. Class 0'].iloc[0])
 cross_df['diff_1'] = abs(cross_df['cos_sim'] - df['Class 0 vs. Class 1'].iloc[0])

预测的代码如下:

 # Add a new column 'predicted_churn'
 cross_df['predicted_churn'] = ''
 
 # Loop through each row and check the minimum difference
 for idx, row in cross_df.iterrows():
     if row['diff_0'] < row['diff_1']:
         cross_df.at[idx, 'predicted_churn'] = 0
     else:
         cross_df.at[idx, 'predicted_churn'] = 1

最后我们看看结果:

 grouped_df__2 = cross_df.groupby(['predicted_churn', 'Churn_x']).size().reset_index(name='count')
 grouped_df__2['percentage'] = grouped_df__2['count'] / grouped_df__2['count'].sum() * 100
 
 grouped_df__2.head()

可以看到,模型的准确率为84.25%。但是我们可以看到,他的混淆矩阵看到对于某些预测要比svm好,也就是说它可以在一定程度上解决类别不平衡的问题。

总结

余弦相似性本身并不能直接解决类别不平衡的问题,因为它只是一种计算相似度的方法,而不是一个分类器。但是,余弦相似性可以作为特征表示方法,来提高类别不平衡数据集的分类性能。本文只是作为一个样例还有可以提高的空间。

本文的数据集在这里:

https://avoid.overfit.cn/post/5cd4d22b523c418cb5d716e942a7ed46

如果你有兴趣可以自行尝试。

作者:Ashutosh Malgaonkar

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/484707.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

什么是链接库 | 动态库与静态库

欢迎关注博主 Mindtechnist 或加入【Linux C/C/Python社区】一起学习和分享Linux、C、C、Python、Matlab&#xff0c;机器人运动控制、多机器人协作&#xff0c;智能优化算法&#xff0c;滤波估计、多传感器信息融合&#xff0c;机器学习&#xff0c;人工智能等相关领域的知识和…

SPSS如何进行基本统计分析之案例实训?

文章目录 0.引言1.描述性分析2.频数分析3.探索分析4.列联表分析5.比率分析 0.引言 因科研等多场景需要进行绘图处理&#xff0c;笔者对SPSS进行了学习&#xff0c;本文通过《SPSS统计分析从入门到精通》及其配套素材结合网上相关资料进行学习笔记总结&#xff0c;本文对基本统计…

深度学习卷积神经网络学习小结2

简介 经过大约两周左右的学习&#xff0c;对深度学习有了一个初步的了解&#xff0c;最近的任务主要是精读深度学习方向的文献&#xff0c;由于搭建caffe平台失败而且比较耗费时间就没有再尝试&#xff0c;所以并没有做实践方面的工作&#xff0c;本文只介绍了阅读文献学到的知…

JdbcTemplate常用语句代码示例

目录 JdbcTemplate 需求 官方文档 JdbcTemplate-基本介绍 JdbcTemplate 使用实例 需求说明 创建数据库 spring 和表 monster 创建配置文件 src/jdbc.properties 创建配置文件 src/JdbcTemplate_ioc.xml 创建类JdbcTemplateTest测试是否可以正确得到数据源 配置 J…

《程序员面试金典(第6版)面试题 16.10. 生存人数(前缀和思想)

题目描述 给定 N 个人的出生年份和死亡年份&#xff0c;第 i 个人的出生年份为 birth[i]&#xff0c;死亡年份为 death[i]&#xff0c;实现一个方法以计算生存人数最多的年份。 你可以假设所有人都出生于 1900 年至 2000 年&#xff08;含 1900 和 2000 &#xff09;之间。如果…

Spring源码解读——高频面试题

Spring IoC的底层实现 1.先通过createBeanFactory创建出一个Bean工厂&#xff08;DefaultListableBeanFactory&#xff09; 2.开始循环创建对象&#xff0c;因为容器中的bean默认都是单例的&#xff0c;所以优先通过getBean、doGetBean从容器中查找&#xff0c;如果找不到的…

LeetCode-1003. 检查替换后的词是否有效

题目链接 LeetCode-1003. 检查替换后的词是否有效 题目描述 题解 题解一&#xff08;Java&#xff09; 作者&#xff1a;仲景 题挺难懂的&#xff0c;很绕&#xff0c;然后读懂了就很简单了 就是说本来是一个字符串s&#xff0c;abc三个字符可以随便放在s原本字符串的左边或…

删除游戏-类似打家劫舍

198. 打家劫舍 - 力扣&#xff08;LeetCode&#xff09; 1 熟悉打家劫舍 你是一个专业的小偷&#xff0c;计划偷窃沿街的房屋。每间房内都藏有一定的现金&#xff0c;影响你偷窃的唯一制约因素就是相邻的房屋装有相互连通的防盗系统&#xff0c;如果两间相邻的房屋在同一晚上被…

java+微信小程序,实现chatgpt聊天小程序

chatgp持续火爆,然鹅会用的人其实挺少的,现在使用异步请求的方式,基本可以实现秒回复。并且还基于webSocket编写了一个微信小程序来进行交互,可以直接使用微信小程序来进行体验。 现在我将所有代码都上传了github(链接在文章结尾),大家可以clone下来,部署到服务器上,真…

shell命令

shell命令 打开文本编辑器(可以使用vi/vim创建文本),新建一个test.sh文件&#xff0c;输入一些代码&#xff0c;第一行为固定写法 #!/bin/bash echo hello word#!是一个约定的标记&#xff0c;他告诉系统这个脚本使用什么解释器执行 shell中注释 1.单行注释使用# 2.多行注释…

在Linux服务器上(非root权限)配置anaconda和pytorch的GPU环境

本人小白一枚&#xff0c;加入了导师的课题组之后使用学校的服务器开始炼丹&#xff0c;但是光是配环境就花了好几天&#xff0c;特此记录下。。。。 选择你趁手的工具 链接远程服务器的终端工具有很多&#xff0c;例如xshell等&#xff0c;我选择是的finalshell 下载教程 【…

敏捷ACP.敏捷估计与规划.Mike Cohn.

第一部分 传统规划失败的原因 vs 敏捷规划有效的原因 传统的项目规划方式往往会让我们失望。要回答-一个 新产品的范围/进度/资源的组合问题&#xff0c;传统规划过程不一定会产生令人非常满意的答案和最终产品。以下- -些论据可以支持这个结论: ●大约2/3的项目会显著超…

Linux设备驱动模型(一)

一、sysfs文件系统 sysfs是一个虚拟文件系统&#xff0c;将内核总的设备对象的链接关系&#xff0c;以文件目录的方式表示出来&#xff0c;并提对设备的供读写接口。 二、kobject kobject是内核中对象表示的基类&#xff0c;可以认为所有的内核对象都是一个kobject kobject单…

Docker 持久化存储 Bind mounts

Docker 持久化存储 Bind mounts Bind mounts 的 -v 与 --mount 区别启动容器基于bind mount挂载到容器中的非空目录只读 bind mountcompose 中使用 bind mount 官方文档&#xff1a;https://docs.docker.com/storage/bind-mounts/ Bind mounts 的 -v 与 --mount 区别 如果使用…

Origin如何使用基础功能?

文章目录 0.引言1.菜单栏2.工具栏 0.引言 因科研等多场景需要进行绘图处理&#xff0c;笔者对Origin进行了学习&#xff0c;本文通过《Origin 2022科学绘图与数据》及其配套素材结合网上相关资料进行学习笔记总结&#xff0c;本文对软件界面基础功能进行阐述。    1.菜单栏 …

【Linux内核解析-linux-5.14.10-内核源码注释】内核源码中宏定义理解

内核宏定义1 这是Linux内核中的start_kernel函数的一部分代码。它的作用是初始化内核的一些基本组件和数据结构。 asmlinkage: 这是一个函数声明修饰符&#xff0c;指示编译器把函数参数放在堆栈中&#xff0c;而不是寄存器中。 __visible: 这是另一个函数声明修饰符&#x…

第二十六章 碰撞体Collision(上)

在游戏世界中&#xff0c;游戏物体之间的交互都是通过“碰撞接触”来进行交互的。例如&#xff0c;攻击怪物则是主角与怪物的碰撞&#xff0c;触发机关则是主角与机关的碰撞。在DirectX课程中&#xff0c;我们也大致介绍过有关碰撞检测的内容。游戏世界中的3D模型的形状是非常复…

生成模型经典算法-VAEGAN(含Python源码例程)

生成模型 文章目录 生成模型1. 概述2. 生成模型典型结构-VAE&GAN2.1 VAE2.1.1 简介2.1.2 模型处理流程 2.2 GAN2.2.1 简介 2.2.2 生成对抗网络要点2.2.3 生成对抗网络的训练准则2.2.4 生成对抗网络模型处理流程 3.生成模型和判别模型在AIGC中的应用3.1 生成模型在AIGC中的应…

【SQL】面试篇之排序和分组练习

1587 银行账户概要 II 1587题目 # Write your MySQL query statement below select name, balance from (select u.account, name, sum(amount) as balancefrom Users uleft join Transactions ton u.account t.accountgroup by u.account ) temp where balance > 10000总…

给定一个文本文件,每行是一条股票信息,写程序提取出所有的股票代码

问题&#xff1a;给定一个文本文件&#xff0c;每行是一条股票信息&#xff0c;写程序提取出所有的股票代码。其中&#xff0c;股票代码规则是&#xff1a;6 位数字&#xff0c; 而且以.SH 或者.SZ 结尾。 文件内容示例&#xff1a; 2020-08-08;平安银行(000001.SZ);15.55;2940…