Seaborn中热力图的绘制方法
seaborn中绘制热力图使用的是sns.heatmap()函数:
sns.heatmap(data,vmin,vmax,cmap,center,robust,annot,fmt=‘.2g’,annot_kws,linewidths=0,linecolor=‘white’,cbar,cbar_kws,cbar_ax,square,xticklabels=‘auto’,yticklabels=‘auto’,mask,ax,**kwargs,)
关键常用参数说明:
data:要绘制热力图的数据集,可以是DataFrame、数组或列表等。
vmin,vmax:设置热力图颜色映射的取值范围,vmin最小值,vmax最大值。
cmap:可选参数,用于指定颜色映射。默认值为"viridis",表示使用viridis颜色映射;可以设置为一个颜色映射名称或颜色映射对象。
center:可选参数,用于指定颜色映射的中心值。默认值为None,表示使用数据的中心值;可以设置为一个数值,表示中心值。
robust:可选参数,用于指定是否使用鲁棒性的颜色映射。默认值为False,表示不使用鲁棒性;可以设置为True,表示使用鲁棒性。
annot:可选参数,用于在每个单元格中显示数值。默认值为False,表示不显示数值;可以设置为True,表示显示数值;可以设置为一个布尔数组或与data形状相同的数组,用于指定要显示的数值。
fmt:可选参数,用于指定数值的格式。默认值为".2g",表示使用科学计数法;可以设置为其他格式字符串,例如"%d"表示整数,“%.2f"表示保留两位小数。
annot_kws:可选参数,用于传递给annot参数的其他参数,例如字体大小、颜色等。
linewidths:可选参数,用于指定每个单元格之间的边框线宽度。默认值为0,表示不显示边框线;可以设置为一个浮点数,表示边框线的宽度。
linecolor:可选参数,用于指定每个单元格之间的边框线颜色。默认值为"white”,表示白色;可以设置为一个颜色名称或颜色代码。
cbar:可选参数,用于指定是否显示颜色条。默认值为True,表示显示颜色条;可以设置为False,表示不显示颜色条。
cbar_kws:可选参数,用于传递给颜色条的其他参数,例如颜色条的标签、方向等。
square:可选参数,用于指定是否将每个单元格绘制为正方形。默认值为False,表示不绘制正方形;可以设置为True,表示绘制正方形。
xticklabels、yticklabels:可选参数,用于指定x轴和y轴的标签。默认值为"auto",表示自动绘制标签;可以设置为一个标签列表或布尔数组,用于指定要显示的标签。
ax:可选参数,用于指定绘制热力图的坐标轴对象。默认值为None,表示使用当前坐标轴。
**kwargs:用于传递其他绘图参数,例如图像的标题、标签、颜色等。
案例说明
【例5.18】使用seaborn绘制银行板块股票和IT设备制造板块股票交易额热力图
代码如下:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
# 设置为默认字体
plt.rcParams['font.family'] = 'SimHei'
# 显示负数
plt.rcParams['axes.unicode_minus'] = False
# 导入数据
df = pd.read_excel("2023年一季度A股日线行情.xlsx")
# 将日期列转化为日期格式
df["trade_date"] = df["trade_date"].astype("str").apply(lambda x:x[:4]+"-"+x[4:6]+"-"+x[6:])
# 将日期列转换为日期类型,并设置为索引列
df['trade_date'] = pd.to_datetime(df['trade_date'])
df.set_index('trade_date', inplace=True)
# 筛选2023年3月的行情数据
start_date = '2023-01-03'
end_date = '2023-03-31'
df = df.loc[start_date:end_date]
# 创建股票池列表-深圳交易所银行类版块股票池,共9支
stock_list1 = ['000001.SZ','001227.SZ','002142.SZ','002807.SZ','002839.SZ',
'002936.SZ','002948.SZ','002958.SZ','002966.SZ']
# 创建股票池列表-深圳交易所IT设备制造类版块股票池,共9支
stock_list2 = ['000066.SZ','000977.SZ','002180.SZ','002197.SZ','002236.SZ',
'002351.SZ','002415.SZ','002528.SZ','002866.SZ']
# 从DataFrame中将股票池中的交易数据筛选出来
stock_1 = df[df['ts_code'].isin(stock_list1)]
stock_2 = df[df['ts_code'].isin(stock_list2)]
# 相比交易量,交易额更能反映出某个版块当前的交易热度
# 按周统计银行类版块的每只股票每周的交易额
stock1_amout_weekly = []
for i in range(len(stock_list1)):
df_temp = stock_1[stock_1['ts_code']==stock_list1[i]]
temp_a = df_temp.resample('W').sum()
temp_b = temp_a['amount'].to_list()
stock1_amout_weekly.append(temp_b)
# 同理,按周统计IT设备制造类版块的每只股票每周的交易额
stock2_amout_weekly = []
for i in range(len(stock_list2)):
df_temp = stock_2[stock_2['ts_code']==stock_list2[i]]
temp_a = df_temp.resample('W').sum()
temp_b = temp_a['amount'].to_list()
stock2_amout_weekly.append(temp_b)
# 创建一个1x2的子图布局
fig, ax = plt.subplots(1, 2, figsize=(15, 6))
# 绘制银行类版块股票的热力图
im1 = sns.heatmap(stock1_amout_weekly, ax = ax[0])
# 设置图表标题和标签
ax[0].set_title('Stock of Banks Trading Amount Heatmap')
# 横坐标设置为周数,纵坐标为各银行
ax[0].set_xlabel("Weeks")
ax[0].set_ylabel("Banks")
# 同理,绘制IT设备制造类版块股票的热力图,并设置标题、横纵坐标。
im2 = sns.heatmap(stock2_amout_weekly, ax = ax[1])
ax[1].set_title('Stock of IT-Vehicles Trading Amount Heatmap')
ax[1].set_xlabel("Weeks")
ax[1].set_ylabel("Banks")
# 显示图表
plt.show()
代码运行效果如下图所示: