目录
- 1 存在的问题
- 2 可能的解决方案
- 3 最终的解决方案
- 3.1 方案一(我已弃用)
- 3.2 方案二(基于方案一)
- 3.3 方案三(基于方案一)
1 存在的问题
李沐老师提供的读取香蕉数据集的函数如下:
def read_data_bananas(is_train=True):
"""读取香蕉检测数据集中的图像和标签"""
data_dir = d2l.download_extract('banana-detection')
csv_fname = os.path.join(data_dir,
'bananas_train' if is_train else 'bananas_val',
'label.csv')
csv_data = pd.read_csv(csv_fname)
csv_data = csv_data.set_index('img_name')
images, targets = [], []
for img_name, target in csv_data.iterrows():
images.append(torchvision.io.read_image(
os.path.join(data_dir, 'bananas_train' if is_train
else 'bananas_val', 'images', f'{img_name}'))
targets.append(list(target))
return images, torch.tensor(targets).unsqueeze(1) / 256
执行到如下代码时报错:
images.append(torchvision.io.read_image(
os.path.join(data_dir, 'bananas_train' if is_train
else 'bananas_val', 'images', f'{img_name}'))
报错内容为:
RuntimeError: No such operator image::read_file
2 可能的解决方案
- 博客一:报错 RuntimeError: No such operator image::read_file
- 它认为是文件路径中
/
和\\
的问题,可惜我全部改为反斜杠后并未解决问题 - 博客二:解决 RuntimeError: No such operator image::read_file
- 它认为是 torchvision 版本的问题,可是重装 torchvision 又会需要重装 torch 等包
- 这是因为 torch 等包的版本要和 torchvision 的版本对应,我认为代价太大
3 最终的解决方案
3.1 方案一(我已弃用)
放弃使用 torchvision.io.read_image()
,换成其他函数来做:
- 采用
Image.open()
函数读取图片(之前看小土堆的视频用过) - 将读取到的图片转换为
tensor
(这是torchvision.io.read_image()
函数的作用之一)
简而言之,我们换成其他代码来实现
torchvision.io.read_image()
函数的作用。
① 增加需要使用到的包:
from PIL import Image
② 增加将图片转换为 tensor
的类:
🥲 让 AI 帮我写的,我是真的写不了一点
class ToTensorNoNorm(torchvision.transforms.ToTensor):
def __call__(self, pic):
return torch.tensor(super().__call__(pic) * 255, dtype=torch.uint8)
Q:为什么不直接使用
torchvision.transforms.ToTensor
类的实例?
A:因为它在将图片转换为tensor
时会进行归一化,而torchvision.io.read_image()
函数是没有这个作用的,所以我们定义一个继承自torchvision.transforms.ToTensor
类但不做归一化的新类。
③ 修改原函数:
def read_data_bananas(is_train=True):
"""读取香蕉检测数据集中的图像和标签"""
data_dir = d2l.download_extract('banana-detection')
csv_fname = os.path.join(data_dir,
'bananas_train' if is_train else 'bananas_val',
'label.csv')
csv_data = pd.read_csv(csv_fname)
# 将 img_name 列设置为索引列
csv_data = csv_data.set_index('img_name')
images, targets = [], []
for img_name, target in csv_data.iterrows():
# 修改部分(三行代码)
image = Image.open(os.path.join(data_dir, 'bananas_train' if is_train
else 'bananas_val', 'images', f'{img_name}'))
transform = ToTensorNoNorm()
images.append(transform(image))
targets.append(list(target))
return images, torch.tensor(targets).unsqueeze(1) / 256
🥳 代码运行成功:
3.2 方案二(基于方案一)
方案二还是基于方案一的思路,但是不再需要自定义将图片转换为 tensor
的类了,因为我发现李沐老师在后面的代码中做了归一化:
imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255
也就是说,
torchvision.transforms.ToTensor
类的归一化不再变得鸡肋。
① 增加需要使用到的包:
from PIL import Image
② 修改原函数:
def read_data_bananas(is_train=True):
"""读取香蕉检测数据集中的图像和标签"""
data_dir = d2l.download_extract('banana-detection')
csv_fname = os.path.join(data_dir,
'bananas_train' if is_train else 'bananas_val',
'label.csv')
csv_data = pd.read_csv(csv_fname)
# 将 img_name 列设置为索引列
csv_data = csv_data.set_index('img_name')
images, targets = [], []
for img_name, target in csv_data.iterrows():
# 修改部分(三行代码)
image = Image.open(os.path.join(data_dir, 'bananas_train' if is_train
else 'bananas_val', 'images', f'{img_name}'))
transform = torchvision.transforms.ToTensor()
images.append(transform(image))
targets.append(list(target))
return images, torch.tensor(targets).unsqueeze(1) / 256
③ 去除后面代码中的归一化:
imgs = (batch[0][0:10].permute(0, 2, 3, 1))
简而言之,
torchvision.transforms.ToTensor
类会对图片做归一化,后面就不需要再做了。
🥳 代码运行成功:
3.3 方案三(基于方案一)
第二天想到的方法
李沐老师的语义分割代码又使用到了 torchvision.io.read_image
函数,我不想每次都要定义一个将图片转换为 tensor 同时又不做归一化的类。
torchvision.io.read_image
函数其实可以被替换为如下代码,以实现相同的效果:
transform = torchvision.transforms.ToTensor() # 实例化 ToTensor 类
image = Image.open(image_dir) # 读取图片
image = (transform(image) * 255).to(torch.uint8) # 转 tensor 但不归一化
注意:图片归一化后的数值是 float 型的小数,即使乘了 255 还是 float 型,需要转换为 integer 型。否则报错 “Clipping input data to the valid range for imshow with RGB data ([0…1] for floats or [0…255] for integers)”,也就是说,针对 RGB 数据,要么是 [0, 1] 之间的 float 型,要么是 [0, 255] 之间的 integer 型。
替换到语义分割一节的代码上:
transform = torchvision.transforms.ToTensor()
image = Image.open(os.path.join(voc_dir, 'JPEGImages', f'{fname}.jpg'))
label = Image.open(os.path.join(voc_dir, 'SegmentationClass' ,f'{fname}.png')).convert('RGB')
features.append((transform(image) * 255).to(torch.uint8))
labels.append((transform(label) * 255).to(torch.uint8))
🥳 代码运行成功: