41目标检测数据集
import os
import pandas as pd
import torch
import torchvision
import matplotlib.pylab as plt
from d2l import torch as d2l
# 数据集下载链接
# http://d2l-data.s3-accelerate.amazonaws.com/banana-detection.zip
# 读取数据集
#@save
def read_data_bananas(is_train=True):
"""读取香蕉检测数据集中的图像和标签"""
data_dir = '../data/banana-detection/'
csv_fname = os.path.join(data_dir, 'bananas_train' if is_train
else 'bananas_val', 'label.csv')
csv_data = pd.read_csv(csv_fname)
# 将 img_name 列设置为索引,以便后续操作中根据图片名称索引标签。
csv_data = csv_data.set_index('img_name')
images, targets = [], [] # images 用于存储图像,targets 用于存储标签。
for img_name, target in csv_data.iterrows():
images.append(torchvision.io.read_image(
os.path.join(data_dir, 'bananas_train' if is_train else
'bananas_val', 'images', f'{img_name}')))
# 这里的target包含(类别,左上角x,左上角y,右下角x,右下角y),
# 其中所有图像都具有相同的香蕉类(索引为0)
targets.append(list(target))
# 将 targets 列表转换为 PyTorch 张量,并增加一个维度(通过 unsqueeze(1))。
# 对标签进行归一化处理(除以 256)。
return images, torch.tensor(targets).unsqueeze(1) / 256 # 增加维度以匹配其他张量的形状
# 图像的小批量的形状为(批量大小、通道数、高度、宽度)
# 标签的小批量的形状为(批量大小,m,5),其中m是数据集的任何图像中边界框可能出现的最大数量。
#@save
class BananasDataset(torch.utils.data.Dataset):
"""一个用于加载香蕉检测数据集的自定义数据集"""
def __init__(self, is_train):
self.features, self.labels = read_data_bananas(is_train)
print('read ' + str(len(self.features)) + (f' training examples' if
is_train else f' validation examples'))
def __getitem__(self, idx):
return (self.features[idx].float(), self.labels[idx])
def __len__(self):
return len(self.features)
#@save
def load_data_bananas(batch_size):
"""加载香蕉检测数据集"""
train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),
batch_size, shuffle=True)
val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),
batch_size)
return train_iter, val_iter
batch_size, edge_size = 32, 256
train_iter, _ = load_data_bananas(batch_size)
batch = next(iter(train_iter))
# print(batch[0].shape, batch[1].shape)
# torch.Size([32, 3, 256, 256]) torch.Size([32, 1, 5])
# 效果演示
imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255
# batch[0] 是包含图像数据的张量,形状为 (batch_size, channels, height, width)
# batch[0][0:10] 选择前 10 个图像。
# .permute(0, 2, 3, 1) 将张量的维度重新排列变为 (batch_size, height, width, channels)
# / 255 将像素值归一化到 [0, 1] 之间
# 图像的像素值通常在0到255之间。如果不进行归一化,像素值直接使用原始范围。
# 图像库在显示图像时,需要将像素值映射到一个合理的范围内。
# 在0到1范围内时,显示库可以更好地处理和展示这些图像。
axes = d2l.show_images(imgs, 2, 5, scale=2)
# d2l.show_images 是一个用于显示多张图像的函数。
# imgs 是预处理后的图像张量。
# 2, 5 指定了图像将被显示为 2 行 5 列的网格。
# scale=2 指定了图像的缩放比例。
# batch[1]是包含图像标签的张量torch.Size([32, 1, 5])
for ax, label in zip(axes, batch[1][0:10]):
d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])
# d2l.show_bboxes 是一个用于在图像上绘制边界框的函数。
# ax 是当前图像的坐标轴。
# label[0][1:5] 提取标签中的边界框坐标(标签格式为 [class, x_min, y_min, x_max, y_max])。
# * edge_size 将边界框坐标缩放到图像的实际尺寸。
# colors=['w'] 指定边界框的颜色为白色。
plt.show()
运行结果: