声明
本文章仅供个人学习使用，版面观感若有不适请谅解；文中知识仅代表个人观点，若有错误，欢迎各位批评指正。
三十二、目标检测和边界框
import torch
import matplotlib.pyplot as plt
from matplotlib_inline import backend_inline
def show_images(imgs, titles=None):
    """Display a single image with an optional title.

    Despite the plural name, ``imgs`` is handed straight to ``plt.imshow`` as
    one image (e.g. the array returned by ``plt.imread``).

    Args:
        imgs: image data accepted by ``plt.imshow``.
        titles: optional title string for the plot.
    """
    # Configure output format, figure size and a CJK-capable font *before*
    # drawing: ``figure.figsize`` is only read when a new figure is created,
    # so setting it after ``plt.imshow`` never affected the figure shown here.
    backend_inline.set_matplotlib_formats('svg')
    plt.rcParams['figure.figsize'] = (6.5, 3.5)
    plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']
    plt.imshow(imgs)
    if titles is not None:
        plt.title(titles)
    plt.show()
# Load the sample cat/dog photo from the author's local path and display it.
img = plt.imread('E:\\cat\\catdog.jpg')
show_images(img, '原图')
def box_corner_to_center(boxes):
    """Convert boxes from corner format to center format.

    (x1, y1, x2, y2) = (top-left x, top-left y, bottom-right x, bottom-right y)
    becomes (cx, cy, w, h) = (center x, center y, width, height).

    Args:
        boxes: tensor whose last dimension holds the four corner coordinates.
            Any leading dimensions are preserved, so ``(N, 4)``,
            ``(batch, N, 4)`` and a single ``(4,)`` box all work.

    Returns:
        Tensor with the same leading shape as ``boxes`` and a last dimension
        of size 4 in center format.
    """
    # Ellipsis indexing selects along the last axis regardless of rank.
    x1, y1, x2, y2 = boxes[..., 0], boxes[..., 1], boxes[..., 2], boxes[..., 3]
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    w = x2 - x1
    h = y2 - y1
    return torch.stack((cx, cy, w, h), dim=-1)
def box_center_to_corner(boxes):
    """Convert boxes from center format to corner format.

    (cx, cy, w, h) = (center x, center y, width, height) becomes
    (x1, y1, x2, y2) = (top-left x, top-left y, bottom-right x, bottom-right y).

    Args:
        boxes: tensor whose last dimension holds (cx, cy, w, h). Any leading
            dimensions are preserved, so ``(N, 4)``, ``(batch, N, 4)`` and a
            single ``(4,)`` box all work.

    Returns:
        Tensor with the same leading shape as ``boxes`` and a last dimension
        of size 4 in corner format.
    """
    # Ellipsis indexing selects along the last axis regardless of rank.
    cx, cy, w, h = boxes[..., 0], boxes[..., 1], boxes[..., 2], boxes[..., 3]
    x1 = cx - 0.5 * w
    y1 = cy - 0.5 * h
    x2 = cx + 0.5 * w
    y2 = cy + 0.5 * h
    return torch.stack((x1, y1, x2, y2), dim=-1)
# "bbox" is short for bounding box. Corner-format boxes (x1, y1, x2, y2)
# for the two dogs and two cats in the sample image.
dog1_bbox, dog2_bbox = [55.0, 265.0, 252.0, 590.0], [400.0, 18.0, 656.0, 590.0]
cat1_bbox, cat2_bbox = [231.0, 188.0, 443.0, 595.0], [650.0, 226.0, 905.0, 590.0]
boxes = torch.tensor((dog1_bbox, dog2_bbox, cat1_bbox, cat2_bbox))
# Round-trip check: corner -> center -> corner should reproduce the input.
# Use a tolerance instead of exact `==`: the float arithmetic in the two
# conversions is not guaranteed to be exact for arbitrary coordinates, and a
# single boolean is easier to read than a 4x4 element-wise mask.
round_trip = box_center_to_corner(box_corner_to_center(boxes))
print('测试函数正确性 : ', torch.allclose(round_trip, boxes))
def bbox_to_rect(bbox, color):
    """Build an unfilled matplotlib Rectangle from a corner-format box.

    ``bbox`` is (top-left x, top-left y, bottom-right x, bottom-right y);
    matplotlib wants ((top-left x, top-left y), width, height).
    """
    left, top = bbox[0], bbox[1]
    right, bottom = bbox[2], bbox[3]
    return plt.Rectangle(xy=(left, top), width=right - left,
                         height=bottom - top, fill=False,
                         edgecolor=color, linewidth=2)
fig = plt.imshow(img)
# Cats outlined in red, dogs in blue (patches added in the original order).
for bbox, color in ((cat1_bbox, 'red'), (cat2_bbox, 'red'),
                    (dog2_bbox, 'blue'), (dog1_bbox, 'blue')):
    fig.axes.add_patch(bbox_to_rect(bbox, color))
plt.axis('off')
plt.suptitle('标记后')
plt.show()
三十三、目标检测数据集
import os
import pandas as pd
import torch
import torchvision
from matplotlib import pyplot as plt
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot images on a ``num_rows`` x ``num_cols`` grid.

    Args:
        imgs: iterable of images; torch tensors are converted to NumPy arrays,
            anything else is passed to ``imshow`` unchanged.
        num_rows, num_cols: grid shape.
        titles: optional sequence of per-image titles.
        scale: size in inches of each grid cell.

    Returns:
        1-D array of the grid's Axes objects (row-major).
    """
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False always yields a 2-D array of Axes, so .flatten() also
    # works for a 1x1 grid (where subplots would otherwise return a bare Axes).
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if isinstance(img, torch.Tensor):
            # Drop autograd history and move to CPU so .numpy() cannot fail;
            # replaces a bare `except: pass` that hid every error.
            img = img.detach().cpu().numpy()
        ax.imshow(img)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes
def bbox_to_rect(bbox, color):
    """Turn an (x1, y1, x2, y2) box into an unfilled matplotlib Rectangle."""
    top_left = (bbox[0], bbox[1])
    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    return plt.Rectangle(xy=top_left, width=width, height=height,
                         fill=False, edgecolor=color, linewidth=2)
def show_bboxes(axes, bboxes, labels=None, colors=None):
    """Draw corner-format bounding boxes, optionally labelled, on an Axes.

    Args:
        axes: matplotlib Axes to draw on.
        bboxes: iterable of (x1, y1, x2, y2) boxes; each may be a torch tensor
            (CPU or GPU) or any indexable sequence such as a list or ndarray.
        labels: optional label or list of labels; box ``i`` is labelled with
            ``labels[i]`` when that index exists.
        colors: optional color or list of colors, cycled across boxes.
            Defaults to ``['b', 'g', 'r', 'm', 'c']`` (also used if empty).
    """
    def make_list(obj, default_values=None):
        # None -> default, scalar -> [scalar], list/tuple -> unchanged.
        if obj is None:
            obj = default_values
        elif not isinstance(obj, (list, tuple)):
            obj = [obj]
        return obj

    default_colors = ['b', 'g', 'r', 'm', 'c']
    labels = make_list(labels)
    # Fall back to the default palette for an empty list, which would
    # otherwise make `i % len(colors)` divide by zero.
    colors = make_list(colors, default_colors) or default_colors
    for i, bbox in enumerate(bboxes):
        color = colors[i % len(colors)]
        if isinstance(bbox, torch.Tensor):
            # Detach and move to CPU so GPU tensors are handled too.
            bbox = bbox.detach().cpu().numpy()
        rect = bbox_to_rect(bbox, color)
        axes.add_patch(rect)
        if labels and len(labels) > i:
            # Black text on white boxes, white text otherwise.
            text_color = 'k' if color == 'w' else 'w'
            axes.text(rect.xy[0], rect.xy[1], labels[i],
                      va='center', ha='center', fontsize=9, color=text_color,
                      bbox=dict(facecolor=color, lw=0))
def read_data_bananas(is_train=True, data_dir='E:\\banana-detection'):
    """Read the images and labels of the banana detection dataset.

    Args:
        is_train: read the ``bananas_train`` split if True, otherwise
            ``bananas_val``.
        data_dir: dataset root containing the split folders. Defaults to the
            author's local path; pass another root to run elsewhere.

    Returns:
        ``(images, targets)``: ``images`` is a list of tensors from
        ``torchvision.io.read_image`` (one per CSV row). ``targets`` is a
        tensor of shape ``(num_images, 1, num_columns)`` built from the CSV
        rows (indexed by ``img_name``), with every value divided by 256
        (presumably the image edge length; confirm against the data).
    """
    split_dir = os.path.join(data_dir,
                             'bananas_train' if is_train else 'bananas_val')
    csv_data = pd.read_csv(os.path.join(split_dir, 'label.csv'))
    csv_data = csv_data.set_index('img_name')
    images, targets = [], []
    for img_name, target in csv_data.iterrows():
        images.append(torchvision.io.read_image(
            os.path.join(split_dir, 'images', f'{img_name}')))
        # Each target is (class, top-left x, top-left y, bottom-right x,
        # bottom-right y); every image has the same banana class (index 0).
        targets.append(list(target))
    # unsqueeze(1) adds a "boxes per image" axis (one box per image here).
    return images, torch.tensor(targets).unsqueeze(1) / 256
class BananasDataset(torch.utils.data.Dataset):
    """Custom Dataset for loading the banana detection dataset.

    Item ``idx`` is ``(image, label)``: the stored image tensor converted with
    ``.float()`` and the matching entry of the label tensor returned by
    ``read_data_bananas``.
    """

    def __init__(self, is_train):
        self.features, self.labels = read_data_bananas(is_train)
        split = 'training' if is_train else 'validation'
        print(f'read {len(self.features)} {split} examples')

    def __getitem__(self, idx):
        # .float() changes the dtype only; pixel values are not rescaled.
        return (self.features[idx].float(), self.labels[idx])

    def __len__(self):
        return len(self.features)
def load_data_bananas(batch_size):
    """Return ``(train_iter, val_iter)`` DataLoaders for the banana dataset.

    The training loader shuffles; the validation loader keeps file order.
    """
    make_loader = torch.utils.data.DataLoader
    train_set = BananasDataset(is_train=True)
    train_iter = make_loader(train_set, batch_size, shuffle=True)
    val_set = BananasDataset(is_train=False)
    val_iter = make_loader(val_set, batch_size)
    return train_iter, val_iter
# Pull one shuffled training batch and preview its first 10 images with boxes.
batch_size, edge_size = 32, 256
train_iter, _ = load_data_bananas(batch_size)
batch = next(iter(train_iter))
print(f'(批量大小、通道数、高度、宽度) : {batch[0].shape}\n'
      f'(批量大小、数据集的任何图像中边界框可能出现的最大数量、5) : {batch[1].shape}')
# Reorder to (N, H, W, C) for imshow and divide by 255 to get [0, 1] floats.
imgs = batch[0][:10].permute(0, 2, 3, 1) / 255
axes = show_images(imgs, 2, 5, scale=2)
for ax, label in zip(axes, batch[1][:10]):
    # Coordinates were divided by 256 when loaded; scale back to pixels.
    pixel_box = label[0][1:5] * edge_size
    show_bboxes(ax, [pixel_box], colors=['r'])
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']
plt.suptitle('数据集展示')
plt.show()
文中部分知识参考：B 站——跟李沐学 AI；百度百科。