目录
1、普通绘制热力图
2、坐标轴标签太多,自定义标签显示
3、不显示热图的网格
1、普通绘制热力图
# -*- coding:utf-8 _*-
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
# 创建数据
data = np.random.random((7,12))
# 计算相关性
corr = np.corrcoef(data)
# 设置需要显示的标签
bands_wavelength = ["400","500","600","700","800","900","1000"]
mask = np.zeros_like(corr,dtype=np.bool_)
mask[np.tril_indices_from(mask)] = True
cmap = sns.diverging_palette(220,10,as_cmap=True)
corr = np.flip(corr, axis=0)
mask = np.flip(mask, axis=0)
ax = sns.heatmap(corr,mask=mask.T,cmap=cmap,square=True,linewidths=0.5,
vmin=np.min(corr), vmax=np.max(corr),cbar=True, # vmin和vmax是自定义显示颜色的范围
xticklabels=bands_wavelength, yticklabels=bands_wavelength[::-1])
# plt.title("平方数", fontsize=24) # 设置标题
# plt.xlabel("值", fontsize=14) # 设置x标题
# plt.ylabel("值的平方", fontsize=14) # 设置y轴标题
# plt.savefig('./热力图.png',bbox_inches='tight',dpi=300)
plt.show()
2、坐标轴标签太多,自定义标签显示
# -*- coding:utf-8 _*-
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
# 创建数据
data = np.random.random((100,12))
# 计算相关性
corr = np.corrcoef(data)
# 设置需要显示的标签
bands_wavelength = ["400","500","600","700","800","900","1000"]
mask = np.zeros_like(corr,dtype=np.bool_)
mask[np.tril_indices_from(mask)] = True
cmap = sns.diverging_palette(220,10,as_cmap=True)
corr = np.flip(corr, axis=0)
mask = np.flip(mask, axis=0)
ax = sns.heatmap(corr,mask=mask.T,cmap=cmap,square=True,linewidths=0.5,
vmin=np.min(corr), vmax=np.max(corr),cbar=True, # vmin和vmax是自定义显示颜色的范围
xticklabels=bands_wavelength, yticklabels=bands_wavelength[::-1])
dx = data.shape[0]/(len(bands_wavelength)-1)
ax.set_xticks([dx*i for i in range(len(bands_wavelength))])
ax.set_yticks([dx*i for i in range(len(bands_wavelength))])
ax.set_xticklabels(bands_wavelength)
ax.set_yticklabels(bands_wavelength[::-1])
# plt.title("平方数", fontsize=24) # 设置标题
# plt.xlabel("值", fontsize=14) # 设置x标题
# plt.ylabel("值的平方", fontsize=14) # 设置y轴标题
# plt.savefig('./热力图.png',bbox_inches='tight',dpi=300)
plt.show()
3、不显示热图的网格
# -*- coding:utf-8 _*-
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
# 创建数据
data = np.random.random((100,12))
# 计算相关性
corr = np.corrcoef(data)
# 设置需要显示的标签
bands_wavelength = ["400","500","600","700","800","900","1000"]
mask = np.zeros_like(corr,dtype=np.bool_)
mask[np.tril_indices_from(mask)] = True
cmap = sns.diverging_palette(220,10,as_cmap=True)
corr = np.flip(corr, axis=0)
mask = np.flip(mask, axis=0)
# 修改linewidths为0即可
ax = sns.heatmap(corr,mask=mask.T,cmap=cmap,square=True,linewidths=0.,
vmin=np.min(corr), vmax=np.max(corr),cbar=True, # vmin和vmax是自定义显示颜色的范围
xticklabels=bands_wavelength, yticklabels=bands_wavelength[::-1])
dx = data.shape[0]/(len(bands_wavelength)-1)
ax.set_xticks([dx*i for i in range(len(bands_wavelength))])
ax.set_yticks([dx*i for i in range(len(bands_wavelength))])
ax.set_xticklabels(bands_wavelength)
ax.set_yticklabels(bands_wavelength[::-1])
# plt.title("平方数", fontsize=24) # 设置标题
# plt.xlabel("值", fontsize=14) # 设置x标题
# plt.ylabel("值的平方", fontsize=14) # 设置y轴标题
# plt.savefig('./热力图.png',bbox_inches='tight',dpi=300)
plt.show()