最近在进行绘图时,遇到了matplotlib画散点图,并根据目标列的类别来设置颜色区间的问题,但是实现的过程较为艰辛。
文章目录
- 一、数据准备
- 二、第一次尝试(失败及其原因)
- 2.1 失败
- 2.2 原因
- 三、第二次尝试(成功)
- 四、总结—plt.scatter()函数的参数
- 4.1 全部常见的参数
- 4.2 其中的c参数
- 4.2.1 使用单一颜色值
- 4.2.2 使用颜色序列
- 4.2.3 使用数值映射
一、数据准备
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
data = pd.read_excel('./ch2-iris.xlsx')
data.head()
我们希望画出sepal length和sepal width之间的散点图,并根据class列的类别来分类。
二、第一次尝试(失败及其原因)
2.1 失败
# 提取 sepal length 和 sepal width 数据
x_axis = data[' sepal length']
y_axis = data['sepal width']
# 提取 class 列的类别作为颜色
colors = data['class']
# 绘制散点图
plt.scatter(x_axis, y_axis, c=colors)
# 设置图标题和坐标轴标签
plt.title('Sepal Length vs Sepal Width')
plt.xlabel('Sepal Length')
plt.ylabel('Sepal Width')
# 显示图像
plt.show()
此时代码报错:
ValueError Traceback (most recent call last)
File d:\Anaconda\envs\PyTorch\lib\site-packages\matplotlib\axes\_axes.py:4375, in Axes._parse_scatter_color_args(c, edgecolors, kwargs, xsize, get_next_color_func)
4374 try: # Is 'c' acceptable as PathCollection facecolors?
-> 4375 colors = mcolors.to_rgba_array(c)
4376 except (TypeError, ValueError) as err:
File d:\Anaconda\envs\PyTorch\lib\site-packages\matplotlib\colors.py:487, in to_rgba_array(c, alpha)
486 else:
--> 487 rgba = np.array([to_rgba(cc) for cc in c])
489 if alpha is not None:
File d:\Anaconda\envs\PyTorch\lib\site-packages\matplotlib\colors.py:487, in (.0)
486 else:
--> 487 rgba = np.array([to_rgba(cc) for cc in c])
489 if alpha is not None:
File d:\Anaconda\envs\PyTorch\lib\site-packages\matplotlib\colors.py:299, in to_rgba(c, alpha)
298 if rgba is None: # Suppress exception chaining of cache lookup failure.
--> 299 rgba = _to_rgba_no_colorcycle(c, alpha)
300 try:
File d:\Anaconda\envs\PyTorch\lib\site-packages\matplotlib\colors.py:374, in _to_rgba_no_colorcycle(c, alpha)
373 return c, c, c, alpha if alpha is not None else 1.
--> 374 raise ValueError(f"Invalid RGBA argument: {orig_c!r}")
...
'Iris-virginica' 'Iris-virginica' 'Iris-virginica' 'Iris-virginica'
'Iris-virginica' 'Iris-virginica' 'Iris-virginica' 'Iris-virginica'
'Iris-virginica' 'Iris-virginica' 'Iris-virginica' 'Iris-virginica'
'Iris-virginica' 'Iris-virginica' 'Iris-virginica' 'Iris-virginica'
'Iris-virginica' 'Iris-virginica' 'Iris-virginica' 'Iris-virginica']
2.2 原因
根据错误信息,ValueError: Invalid RGBA argument,似乎是在尝试将 ‘class’ 列的值作为颜色传递给 scatter() 函数时出现了错误。
scatter() 函数要求颜色参数是合法的 RGBA(红绿蓝透明度)值,但 ‘class’ 列的值是字符串类型,不符合颜色参数的要求,因此导致了错误。
如果想将 ‘class’ 列的值用作颜色分类,可以将其转换为数值类型或使用其他方法来将字符串值映射为颜色。
三、第二次尝试(成功)
以下是一种可能的解决方法,将 ‘class’ 列的字符串值映射为颜色编码:
x_axis = data[' sepal length']
y_axis = data['sepal width']
color_map = {'Iris-setosa': 'red', 'Iris-versicolor': 'green', 'Iris-virginica': 'blue'}
# 使用 map 方法将 'class' 列的值映射为颜色编码
colors = data['class'].map(color_map)
c = data['class'].map(color_map)
plt.scatter(x_axis, y_axis, c = c)
plt.xlabel('Sepal Length')
plt.ylabel('Sepal Width')
plt.title('Scatter plot of Sepal Length vs Sepal Width')
plt.show()
运行结果如下:
这样,‘class’ 列的不同类别将分别用红色、绿色和蓝色表示在散点图中。您可以根据需要自定义颜色编码映射字典,以适应您的数据和可视化需求。
四、总结—plt.scatter()函数的参数
4.1 全部常见的参数
plt.scatter() 函数是 Matplotlib 库中用于绘制散点图的函数,它的常用参数如下:
plt.scatter(x, y, s=None, c=None, marker=None, cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, edgecolors=None, *, plotnonfinite=False, data=None, **kwargs)
其中,最常用的参数包括:
x
:指定散点图的 x 坐标,可以是一个 NumPy 数组、Pandas Series 或 Python 列表;y
:指定散点图的 y 坐标,可以是一个 NumPy 数组、Pandas Series 或 Python 列表;s
:指定散点的大小,可以是一个数值或表示大小的数组,用于控制散点的尺寸;c
:指定散点的颜色,可以是一个数值或表示颜色的数组,用于控制散点的颜色;marker
:指定散点的标记样式,默认为 ‘o’,可以使用常见的标记样式,如 ‘o’、‘s’、‘d’、‘^’ 等;cmap
:指定颜色映射(colormap),用于将数值映射为颜色,一般与 c 参数一起使用;alpha
:指定散点的透明度,取值范围为 0 到 1,0 表示完全透明,1 表示完全不透明;linewidths
:指定散点边界的宽度;edgecolors
:指定散点边界的颜色。
除了上述参数外,plt.scatter() 函数还可以接受其他关键字参数(**kwargs)用于进一步自定义散点图的样式和属性。
注意:参数的具体用法和取值范围可以参考 Matplotlib 官方文档或使用 help(plt.scatter) 查看详细说明。
4.2 其中的c参数
在 plt.scatter() 函数中,c 参数用于指定散点图中的颜色。c 可以接受不同类型的输入:
- 单一颜色值:可以使用字符串表示颜色,如 ‘red’、‘green’、‘blue’ 等,表示所有的散点都使用相同的颜色;
- 颜色序列:可以使用列表、数组或 Series 对象,表示每个散点的颜色。例如,可以传入一个长度与散点数目相等的列表,其中每个元素表示对应散点的颜色;
- 数值映射:可以使用数值映射函数,将数值映射到颜色。例如,可以传入一个与散点数目相等的数值序列,然后使用 cmap 参数指定颜色映射,将数值映射为对应的颜色。
以下是 plt.scatter() 函数中 c 参数的一些常见用法:
4.2.1 使用单一颜色值
import matplotlib.pyplot as plt
x = [1, 2, 3, 4, 5]
y = [2, 4, 6, 8, 10]
colors = 'red' # 指定所有散点的颜色为红色
plt.scatter(x, y, c=colors)
plt.show()
4.2.2 使用颜色序列
import matplotlib.pyplot as plt
x = [1, 2, 3, 4, 5]
y = [2, 4, 6, 8, 10]
colors = ['red', 'green', 'blue', 'yellow', 'purple'] # 指定每个散点的颜色
plt.scatter(x, y, c=colors)
plt.show()
4.2.3 使用数值映射
import matplotlib.pyplot as plt
import numpy as np
x = np.random.rand(50)
y = np.random.rand(50)
colors = np.random.rand(50) # 生成随机数值序列作为颜色映射
plt.scatter(x, y, c=colors, cmap='viridis') # 使用 viridis 色彩映射
plt.colorbar() # 显示颜色映射条
plt.show()
注意:c 参数只接受长度与 x 和 y 相同的输入序列,用于指定每个散点的颜色。如果 c 参数输入的序列长度不符合要求,将会引发错误。