一、问题描述
在绘制相关分析热力图的时候:
import seaborn as sns
to_corr = ['Age', 'Income', 'Kidhome', 'Teenhome', 'Recency', 'Complain',
'MntWines', 'MntFruits', 'MntMeatProducts', 'MntFishProducts', 'MntSweetProducts',
'MntGoldProds', 'NumDealsPurchases', 'AcceptedCmp1', 'AcceptedCmp2', 'AcceptedCmp3',
'AcceptedCmp4', 'AcceptedCmp5', 'Response', 'NumWebPurchases', 'NumCatalogPurchases',
'NumStorePurchases', 'NumWebVisitsMonth', 'Years_Since_Registration', 'Family_Size',
'Sum_Mnt', 'Num_Accepted_Cmp', 'Num_Total_Purchases']
cmap = sns.diverging_palette(220, 10, as_cmap = True)
matrix = np.triu(data[to_corr].corr())
plt.figure(figsize = (25, 14))
plt.title('Correlation matrix', fontsize = 18)
sns.heatmap(data = [to_corr].corr(), annot = True, fmt = '.1f', vmin = -0.4, center = 0, cmap = cmap, mask = matrix)
plt.show()
- 首先,定义了一个包含待计算相关系数的特征列表to_corr,其中包括了数据集中的多个特征,如年龄(‘Age’)、收入(‘Income’)、家庭成员数量(‘Kidhome’、‘Teenhome’)、购买金额(‘MntWines’、‘MntFruits’、'MntMeatProducts’等)、购买次数(‘NumDealsPurchases’、'NumWebPurchases’等)、注册年限(‘Years_Since_Registration’)等。
- 使用 seaborn 库中的
diverging_palette()
函数生成一个调色板 cmap,用于定义热力图的颜色。 - 利用 numpy 库中的 triu() 函数计算 to_corr 列中的特征之间的相关系数矩阵,并将其赋值给变量matrix。
np.triu()
函数将矩阵的下三角部分置为零,只保留上三角部分,用于在热力图中显示上三角矩阵。 - 创建一个图形对象,设置图形的大小和标题。
- 使用 seaborn 库中的
heatmap()
函数生成热力图,将计算得到的相关系数矩阵作为数据传递给 data 参数,设置annot=True
以在图中显示相关系数的数值,设置fmt='.1f'
以将数值格式化为一位小数,设置vmin=-0.4
和center=0
以调整颜色映射的范围,设置cmap=cmap
以使用之前定义的调色板,设置mask=matrix
以将矩阵的下三角部分遮挡,最后使用 plt.show() 函数显示图形。
通过生成的热力图,可以直观地观察到数据集中各个特征之间的相关性,颜色越深表示相关性越强,颜色越浅表示相关性越弱。这可以帮助您了解不同特征之间的关系,从而在数据分析和建模过程中作出更明智的决策。
但是代码报错:
二、报错原因及改正
这个错误是因为在 sns.heatmap()
函数中,data 参数传入了一个列表 to_corr,而列表没有 corr() 方法,导致了 AttributeError 错误的发生。
corr() 方法用于计算数据的相关系数矩阵,它应该作用于一个数据框(DataFrame)对象,而不是列表。你可以通过从原始数据框中选择需要计算相关系数的列,创建一个新的数据框,然后将其传递给 sns.heatmap() 函数。
以下是修复错误的示例代码:
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
to_corr = ['Age', 'Income', 'Kidhome', 'Teenhome', 'Recency', 'Complain',
'MntWines', 'MntFruits', 'MntMeatProducts', 'MntFishProducts', 'MntSweetProducts',
'MntGoldProds', 'NumDealsPurchases', 'AcceptedCmp1', 'AcceptedCmp2', 'AcceptedCmp3',
'AcceptedCmp4', 'AcceptedCmp5', 'Response', 'NumWebPurchases', 'NumCatalogPurchases',
'NumStorePurchases', 'NumWebVisitsMonth', 'Years_Since_Registration', 'Family_Size',
'Sum_Mnt', 'Num_Accepted_Cmp', 'Num_Total_Purchases']
cmap = sns.diverging_palette(220, 10, as_cmap=True)
matrix = np.triu(data[to_corr].corr())
plt.figure(figsize=(25, 14))
plt.title('Correlation matrix', fontsize=18)
sns.heatmap(data=data[to_corr].corr(), annot=True, fmt='.1f', vmin=-0.4, center=0, cmap=cmap, mask=matrix)
plt.show()
在修复后的代码中,data[to_corr].corr()
会计算 to_corr
列中的相关系数矩阵,并作为参数传递给 sns.heatmap()
函数进行绘图。
运行结果如下: