一、说明
Matplotlib是一个强大的Python数据可视化库,可以绘制各种类型的图形,其中包括热图。热图通常用于表现数据的分布和趋势。本文用一个简单的例子,告诉大家用Matplotlib绘制热图的基本操作语句。
二、热图的概念
2.1 基本概念
热图(heatmap)是数据分析的常用方法,通过色差、亮度来展示数据的差异、易于理解。Python在Matplotlib库中,调用imshow()函数实现热图绘制。
参考资料:http://matplotlib.org/users/image_tutorial.html
2.2 热图绘制方法
一般化例子代码
import matplotlib.pyplot as plt
import numpy as np
# 生成随机数据
data = np.random.rand(10, 10)
# 绘制热图
fig, ax = plt.subplots()
im = ax.imshow(data)
# 设置刻度
ax.set_xticks(np.arange(10))
ax.set_yticks(np.arange(10))
# 将刻度标签替换为数组值
ax.set_xticklabels(np.arange(1, 11))
ax.set_yticklabels(np.arange(1, 11))
# 添加颜色条
cbar = ax.figure.colorbar(im, ax=ax)
# 设置图形标题
ax.set_title("Heatmap Example")
# 显示图形
plt.show()
这段代码将随机生成一个10x10的数组,并将其用作热图的数据。然后,我们创建一个图形和一个轴对象,并使用Matplotlib中的imshow函数将数据绘制为一个热图。设置刻度和刻度标签以显示数据的行和列,添加一个颜色条以表示数据范围,并设置图形标题。最后,我们使用show函数显示热图。
三、imshow函数说明
3.1 函数原型
imshow(X, cmap=None, norm=None, aspect=None, interpolation=None, alpha=None, vmin=None, vmax=None, origin=None, extent=None, shape=None, filternorm=1, filterrad=4.0, imlim=None, resample=None, url=None, hold=None, data=None, **kwargs)
3.2 函数参数表
参数名称 | 参数作用 | 备注 |
---|---|---|
X | 二维数组,表示要显示的图像。 X变量存储图像,可以是浮点型数组、unit8数组以及PIL图像,如果其为数组,则需满足一下形状: | 输入 |
cmap=None | 颜色映射。常见的有 hot 从黑平滑过度到红、橙色和黄色的背景色,然后到白色。 | |
interpolation | 插值方式。常见的有 | |
aspect | 图像长宽比。 | |
vmin | 图像的颜色最小值。 | |
vmax : | 图像的颜色最大值。 | |
alpha | 透明度 | |
origin | 坐标轴原点的位置。可以设置为upper 或lower 。 | |
extent | 控制显示的数据范围。可以设置为[xmin, xmax, ymin, ymax] | |
shape | ||
origin : |
| |
extent |
| |
filternorm 和 filterrad : | 用于图像滤波的对象。可以设置为 | |
imlim | 用于指定图像显示范围。 | |
resample : | 用于指定图像重采样方式。 | |
url | 用于指定图像链接。 |
四、imshow使用案例
显示二维高斯分布的blob图:
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
# 资料 https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.gaussian_kde.html
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
mean = [0,0]
cov = [[1,1],[1,2]]
x,y = np.random.multivariate_normal(mean, cov, 10000).T
# 拟合数组维度
data = np.vstack([x, y])
kde = stats.gaussian_kde(data)
# 用一对规则的网络数据进行拟合
xgrid = np.linspace(-3.5, 3.5, 200)
ygrid = np.linspace(-6, 6, 200)
Xgrid, Ygrid = np.meshgrid(xgrid, ygrid)
Z = kde.evaluate(np.vstack([Xgrid.ravel(), Ygrid.ravel()]))
# 画出结果图
plt.imshow(Z.reshape(Xgrid.shape),
origin='lower',aspect='auto',
extent=[-3.5, 3.5, -6, 6],cmap='Blues')
plt.xlabel('速度')
plt.ylabel('位置')
plt.show()
运行结果