✨Blog homepage: 王乐予🎈
Contents
- 😺0. Repository source code
- 😺1. Dataset introduction
- 🐶1.1 GitHub raw dataset
- 🐶1.2 GitHub preprocessed dataset
- 🦄1.2.1 Simplified drawing files (.ndjson)
- 🦄1.2.2 Binary files (.bin)
- 🦄1.2.3 Numpy bitmaps (.npy)
- 🐶1.3 Kaggle dataset
- 😺2. Dataset preparation
- 😺3. Generating PNG images
- 😺4. Training
- 🐶4.1 split_datasets.py
- 🐶4.2 option.py
- 🐶4.3 getdata.py
- 🐶4.4 model.py
- 🐶4.5 train-DDP.py
- 🐶4.6 model_transfer.py
- 🐶4.7 evaluate.py
😺0. Repository source code
All code in this article is hosted in the GitHub repository QuickDraw-DDP. Forks and stars are welcome.
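😺1. Dataset introduction
The Quick Draw dataset is a collection of 50 million drawings across 345 categories, contributed by players of the game Quick, Draw!. The drawings were captured as timestamped vectors and tagged with metadata, including what the player was asked to draw and the country in which the player was located.
GitHub dataset: The Quick, Draw! Dataset
Kaggle dataset: Quick, Draw! Doodle Recognition Challenge
GitHub offers two types of data: the raw dataset and the preprocessed dataset.
Google Cloud hosts the download links: quickdraw_dataset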
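🐶1.1 GitHub raw dataset
The raw data is provided as ndjson files, one file per category, in the following format:
| Key | Type | Description |
|---|---|---|
| key_id | 64-bit unsigned integer | A unique identifier across all drawings |
| word | string | The category the player was prompted to draw |
| recognized | boolean | Whether the drawing was recognized by the game |
| timestamp | datetime | When the drawing was created |
| countrycode | string | Two-letter country code of the player's location (ISO 3166-1 alpha-2) |
| drawing | string | A JSON array representing the vector drawing |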
Each line holds one drawing. Here is an example of a single drawing:
{
"key_id":"5891796615823360",
"word":"nose",
"countrycode":"AE",
"timestamp":"2017-03-01 20:41:36.70725 UTC",
"recognized":true,
"drawing":[[[129,128,129,129,130,130,131,132,132,133,133,133,133,...]]]
}
The drawing field has the following format:
[
[ // First stroke
[x0, x1, x2, x3, ...],
[y0, y1, y2, y3, ...],
[t0, t1, t2, t3, ...]
],
[ // Second stroke
[x0, x1, x2, x3, ...],
[y0, y1, y2, y3, ...],
[t0, t1, t2, t3, ...]
],
... // Additional strokes
]
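Here x and y are pixel coordinates, and t is the time in milliseconds since the first point. Because different devices were used for display and input, the raw drawings can have very different bounding boxes and numbers of points.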
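To make the nested layout concrete, here is a minimal sketch (standard library only) that walks a drawing in this format and reports its stroke count, point count, bounding box and, for raw data that still carries the t arrays, its duration. The function name summarize_drawing and the example strokes are made up for illustration.

```python
def summarize_drawing(drawing):
    """Summarize one drawing given as [[xs, ys(, ts)], ...] strokes."""
    xs = [x for stroke in drawing for x in stroke[0]]
    ys = [y for stroke in drawing for y in stroke[1]]
    summary = {
        "n_strokes": len(drawing),
        "n_points": len(xs),
        "bbox": (min(xs), min(ys), max(xs), max(ys)),  # (min_x, min_y, max_x, max_y)
    }
    # Raw strokes carry a third array of timestamps (ms since the first point);
    # simplified strokes only have x and y, so no duration is reported for them.
    if all(len(stroke) == 3 for stroke in drawing):
        summary["duration_ms"] = max(stroke[2][-1] for stroke in drawing)
    return summary


# Two hypothetical raw strokes: [xs, ys, ts]
example = [
    [[10, 20, 30], [5, 15, 5], [0, 40, 80]],
    [[12, 12], [40, 60], [300, 350]],
]
print(summarize_drawing(example))
```

The same helper works on the simplified files described next, which keep only the x and y arrays.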
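🐶1.2 GitHub preprocessed dataset
🦄1.2.1 Simplified drawing files (.ndjson)
The vectors are simplified, the timing information is removed, and the data is positioned and scaled into a 256x256 region. The data is exported in ndjson format with the same metadata as the raw format. The simplification process was:
- Align the drawing to the top-left corner, so that the minimum values are 0.
- Uniformly scale the drawing, so that the maximum value is 255.
- Resample all strokes with a 1-pixel spacing.
- Simplify all strokes using the Ramer-Douglas-Peucker algorithm with an epsilon value of 2.0.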
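The steps above can be sketched as follows. This is an illustration under assumptions, not Google's preprocessing code: the 1-pixel resampling step is omitted, and the helper names (align_and_scale, rdp, simplify_stroke) are hypothetical. Alignment subtracts the per-drawing minimum; scaling divides by the larger of width and height so both axes share one factor; RDP recursively keeps the point farthest from the chord between a segment's endpoints whenever that distance exceeds epsilon.

```python
import math


def align_and_scale(drawing, max_value=255):
    """Shift a drawing to start at (0, 0) and scale it uniformly so the max coordinate is max_value."""
    xs = [x for s in drawing for x in s[0]]
    ys = [y for s in drawing for y in s[1]]
    min_x, min_y = min(xs), min(ys)
    extent = max(max(xs) - min_x, max(ys) - min_y) or 1  # avoid dividing by 0 for a single dot
    return [
        [[(x - min_x) * max_value / extent for x in s[0]],
         [(y - min_y) * max_value / extent for y in s[1]]]
        for s in drawing
    ]


def _point_line_distance(p, a, b):
    """Perpendicular distance from point p to the line through a and b."""
    (px, py), (ax, ay), (bx, by) = p, a, b
    dx, dy = bx - ax, by - ay
    if dx == 0 and dy == 0:
        return math.hypot(px - ax, py - ay)
    return abs(dy * px - dx * py + bx * ay - by * ax) / math.hypot(dx, dy)


def rdp(points, epsilon=2.0):
    """Ramer-Douglas-Peucker simplification of a list of (x, y) points."""
    if len(points) < 3:
        return list(points)
    first, last = points[0], points[-1]
    index, dmax = 0, 0.0
    for i in range(1, len(points) - 1):
        d = _point_line_distance(points[i], first, last)
        if d > dmax:
            index, dmax = i, d
    if dmax > epsilon:
        left = rdp(points[: index + 1], epsilon)
        right = rdp(points[index:], epsilon)
        return left[:-1] + right  # drop the duplicated split point
    return [first, last]


def simplify_stroke(stroke, epsilon=2.0):
    """Apply RDP to one [xs, ys] stroke and round back to integer pixels."""
    kept = rdp(list(zip(stroke[0], stroke[1])), epsilon)
    return [[round(x) for x, _ in kept], [round(y) for _, y in kept]]
```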
The code to read an ndjson file is as follows:
# read_ndjson.py
import json
with open('aircraft carrier.ndjson', 'r', encoding='utf-8') as file:
    for line in file:  # each line is one JSON-encoded drawing
        data = json.loads(line)
        key_id = data['key_id']
        drawing = data['drawing']
        # ……
After reading aircraft carrier.ndjson and inspecting it in a debugger, the output looks like the figure below. The first line of data contains 8 strokes.
🦄1.2.2 Binary files (.bin)
The simplified drawings and metadata are also available in a custom binary format for efficient compression and loading.
The code to read a bin file is as follows:
# read_bin.py
import struct
from struct import unpack
def unpack_drawing(file_handle):
    key_id, = unpack('Q', file_handle.read(8))
    country_code, = unpack('2s', file_handle.read(2))
    recognized, = unpack('b', file_handle.read(1))
    timestamp, = unpack('I', file_handle.read(4))
    n_strokes, = unpack('H', file_handle.read(2))
    image = []
    for i in range(n_strokes):
        n_points, = unpack('H', file_handle.read(2))
        fmt = str(n_points) + 'B'
        x = unpack(fmt, file_handle.read(n_points))
        y = unpack(fmt, file_handle.read(n_points))
        image.append((x, y))
    return {
        'key_id': key_id,
        'country_code': country_code,
        'recognized': recognized,
        'timestamp': timestamp,
        'image': image
    }
def unpack_drawings(filename):
    with open(filename, 'rb') as f:
        while True:
            try:
                yield unpack_drawing(f)
            except struct.error:  # raised when the end of the file is reached
                break
for drawing in unpack_drawings('nose.bin'):
    # do something with the drawing
    print(drawing['country_code'])
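To see the byte layout that unpack_drawing walks through, a hypothetical writer can produce the same record and read it back. This is a sketch: pack_drawing is not part of the dataset tooling, and it pins the byte order to little-endian ('<'), which matches the native order the reader above relies on for common x86 and ARM machines; treat that as an assumption about the files. Each record is a 17-byte header followed, per stroke, by a 2-byte point count and then the x bytes and y bytes.

```python
import io
import struct


def pack_drawing(key_id, country_code, recognized, timestamp, strokes):
    """Hypothetical writer for the record layout read by unpack_drawing."""
    buf = bytearray()
    buf += struct.pack('<Q', key_id)        # 8 bytes
    buf += struct.pack('<2s', country_code) # 2 bytes, e.g. b'AE'
    buf += struct.pack('<b', recognized)    # 1 byte
    buf += struct.pack('<I', timestamp)     # 4 bytes
    buf += struct.pack('<H', len(strokes))  # 2 bytes: number of strokes
    for xs, ys in strokes:
        buf += struct.pack('<H', len(xs))             # points in this stroke
        buf += struct.pack(f'<{len(xs)}B', *xs)       # x values, 0..255
        buf += struct.pack(f'<{len(ys)}B', *ys)       # y values, 0..255
    return bytes(buf)


def unpack_drawing(f):
    """Same logic as the reader above, with the byte order made explicit."""
    key_id, = struct.unpack('<Q', f.read(8))
    country_code, = struct.unpack('<2s', f.read(2))
    recognized, = struct.unpack('<b', f.read(1))
    timestamp, = struct.unpack('<I', f.read(4))
    n_strokes, = struct.unpack('<H', f.read(2))
    image = []
    for _ in range(n_strokes):
        n_points, = struct.unpack('<H', f.read(2))
        x = struct.unpack(f'<{n_points}B', f.read(n_points))
        y = struct.unpack(f'<{n_points}B', f.read(n_points))
        image.append((x, y))
    return {'key_id': key_id, 'country_code': country_code,
            'recognized': recognized, 'timestamp': timestamp, 'image': image}


blob = pack_drawing(5891796615823360, b'AE', 1, 1488400896,
                    [((1, 2, 3), (4, 5, 6)), ((10,), (20,))])
print(len(blob))
print(unpack_drawing(io.BytesIO(blob)))
```

Reading past the last record makes struct.unpack fail on a short read, which is why unpack_drawings stops on struct.error.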
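🦄1.2.3 Numpy bitmaps (.npy)
All the simplified drawings have been rendered into 28x28 grayscale bitmaps in numpy format. These images were generated from the simplified data, but are aligned to the center of the drawing's bounding box rather than to the top-left corner.
The code to read an npy file is as follows: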
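# read_npy.py
import numpy as np
data_path = 'aircraft_carrier.npy'
data = np.load(data_path)  # one flattened 28x28 bitmap per row
print(data.shape)
print(data[0].reshape(28, 28))  # the first drawing as a 28x28 image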
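The difference between top-left alignment (simplified ndjson) and bounding-box centering (npy bitmaps) is just a translation. A minimal sketch, assuming a 256-pixel canvas whose center is at 127.5 and using a hypothetical helper center_in_canvas; the actual rendering code behind the npy files is not shown here, and the downscaling to 28x28 is omitted.

```python
def center_in_canvas(drawing, canvas=256):
    """Translate [xs, ys] strokes so the bounding-box center lands on the canvas center."""
    xs = [x for stroke in drawing for x in stroke[0]]
    ys = [y for stroke in drawing for y in stroke[1]]
    target = (canvas - 1) / 2  # 127.5 for pixel indices 0..255
    dx = target - (min(xs) + max(xs)) / 2
    dy = target - (min(ys) + max(ys)) / 2
    return [[[x + dx for x in stroke[0]], [y + dy for y in stroke[1]]] for stroke in drawing]


top_left = [[[0, 100], [0, 50]]]  # a simplified-style stroke touching the top-left corner
print(center_in_canvas(top_left))
```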
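🐶1.3 Kaggle dataset
The Kaggle competition uses 340 categories, and all data is provided as csv tables. The dataset contains 5 files:
- sample_submission.csv - a sample submission file in the correct format
- test_raw.csv - the test data in raw vector format
- test_simplified.csv - the test data in simplified vector format
- train_raw.zip - the training data in raw vector format; one CSV file per word
- train_simplified.zip - the training data in simplified vector format; one CSV file per word
Note: the column titles of the csv files match the key names of the ndjson files.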
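Because each csv row stores the drawing column as a JSON string, it can be parsed with json.loads rather than eval. A minimal standard-library sketch; it assumes the column names from the table above and that recognized is written as True/False, and the sample row is made up. For a real file, open it with newline='' and pass the file object to iter_drawings (a hypothetical helper).

```python
import csv
import io
import json


def iter_drawings(csv_file):
    """Yield one dict per row of a Kaggle train_simplified-style csv file."""
    for row in csv.DictReader(csv_file):
        yield {
            'key_id': row['key_id'],
            'word': row['word'],
            'countrycode': row['countrycode'],
            'recognized': row['recognized'] == 'True',
            'drawing': json.loads(row['drawing']),  # JSON text -> nested lists of strokes
        }


sample = io.StringIO(
    'countrycode,drawing,key_id,recognized,timestamp,word\n'
    'US,"[[[0, 10, 20], [0, 5, 0]]]",123,True,2017-03-01 20:41:36.70725,line\n'
)
for d in iter_drawings(sample):
    print(d['word'], len(d['drawing']), d['recognized'])
```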
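😺2. Dataset preparation
This article uses the train_simplified data provided by Kaggle. The workflow is:
- Save every class's csv file as png images;
- Take 10000 png images from each of the 340 classes for the experiments;
- Split each class's 10000 images into training, validation, and test sets at a ratio of 8:1:1;
- Train the model;
- Evaluate the model.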
😺3. Generating PNG images
The following script converts the csv data into png images.
# csv2png.py
import json
import os
import matplotlib
matplotlib.use('Agg')  # non-interactive backend; select it before importing pyplot
import matplotlib.pyplot as plt
import pandas as pd
input_dir = 'kaggle/train_simplified'
output_base_dir = 'datasets256'
os.makedirs(output_base_dir, exist_ok=True)
csv_files = [f for f in os.listdir(input_dir) if f.endswith('.csv')]  # Retrieve all CSV files from the folder
skipped_files = []  # Record skipped files
for csv_file in csv_files:
    csv_file_path = os.path.join(input_dir, csv_file)  # Build a complete file path
    output_dir = os.path.join(output_base_dir, os.path.splitext(csv_file)[0])  # Build output directory
    if os.path.exists(output_dir):  # Skip categories that another run has already started
        skipped_files.append(csv_file)
        print(f'The directory already exists, skip file: {csv_file}')
        continue
    os.makedirs(output_dir, exist_ok=True)
    data = pd.read_csv(csv_file_path)  # Read CSV file
    for index, row in data.iterrows():  # Traverse each row of data
        drawing = json.loads(row['drawing'])  # the column holds a JSON array; safer than eval()
        key_id = row['key_id']
        word = row['word']
        fig = plt.figure(figsize=(256/96, 256/96), dpi=96)  # 256x256 pixels
        for stroke in drawing:  # Draw each stroke as a black polyline
            plt.plot(stroke[0], stroke[1], 'k')
        ax = plt.gca()
        ax.invert_yaxis()  # image coordinates: y grows downwards
        plt.axis('off')
        plt.savefig(os.path.join(output_dir, f'{word}-{key_id}.png'))
        plt.close(fig)
        print(f'Conversion completed: {csv_file} image {index:06d}')
print("The skipped files are:")
for file in skipped_files:
    print(file)
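Note: there are roughly 50 million drawings, so the conversion takes a very long time. It is worth running several instances of the script in parallel (the code checks whether a category's output folder already exists, so no category is written twice). The joblib library can also parallelize the work, but a misconfigured setup easily brings the machine down, so it is not recommended.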
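Instead of having every instance race through the same file list, each instance can be given a disjoint share of the csv files up front. A minimal sketch with a hypothetical shard helper: sorting first makes the assignment identical in every process, and items[i::n] hands out the files round-robin. Wiring it in would mean passing a shard index and count to csv2png.py (for example via sys.argv) and looping over shard(csv_files, index, count) instead of csv_files; that command-line interface is an assumption, not part of the original script.

```python
def shard(items, shard_index, num_shards):
    """Return the items that instance `shard_index` of `num_shards` should process."""
    if not 0 <= shard_index < num_shards:
        raise ValueError('shard_index must be in [0, num_shards)')
    return sorted(items)[shard_index::num_shards]


files = ['zebra.csv', 'ant.csv', 'cat.csv', 'bee.csv', 'dog.csv']
for i in range(2):
    print(i, shard(files, i, 2))
```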
The storage sizes of the relevant files are:
- GitHub preprocessed ndjson files: 23 GB;
- Kaggle train_raw.zip: 206 GB;
- Kaggle train_simplified.zip: 23 GB;
- Kaggle train_simplified converted to 256x256 images: 470 GB.
If disk space is insufficient, convert to 128x128 or 64x64 images instead, or save single-channel images.
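csv2png.py sets the output resolution through the matplotlib figure size in inches times the dpi, so a smaller image only needs a different figsize. A small sketch with a hypothetical figure_kwargs helper; note that with the default --loadsize of 64 in option.py, the training pipeline resizes every image to 64x64 anyway.

```python
def figure_kwargs(size_px, dpi=96):
    """Arguments for plt.figure() giving a size_px x size_px canvas (pixels = inches * dpi)."""
    side_in = size_px / dpi  # matplotlib measures figure size in inches
    return {'figsize': (side_in, side_in), 'dpi': dpi}


for size in (256, 128, 64):
    kw = figure_kwargs(size)
    print(size, kw['figsize'], round(kw['figsize'][0] * kw['dpi']))
```

For example, plt.figure(**figure_kwargs(64)) in csv2png.py would target 64x64 PNGs.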
After processing, it is a good idea to check for unprocessed categories with the following script:
# check_class_num.py
import os
folder = 'datasets256'
subfolders = [f.path for f in os.scandir(folder) if f.is_dir()]
for subfolder in subfolders:  # Traverse each subfolder
    folder_name = os.path.basename(subfolder)  # Get the name of the subfolder
    files = [f for f in os.scandir(subfolder) if f.is_file()]  # Retrieve all files in the subfolder
    image_count = sum(1 for f in files if f.name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')))  # Count the images
    if image_count == 0:  # No images: report the subfolder and delete it
        print(f"There are no images in the subfolder '{folder_name}', deleting it...")
        os.rmdir(subfolder)
        print(f"Subfolder '{folder_name}' deleted")
    else:
        print(f"Number of images in subfolder '{folder_name}': {image_count}")
If an empty folder is found, run csv2png.py again.
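The check above only catches empty folders. A folder left half-filled by an interrupted csv2png.py run is not empty, yet csv2png.py skips it on the next run because the folder exists. A minimal sketch of a stricter check with a hypothetical find_incomplete_classes helper that flags any class below a threshold; the default of 10000 assumes the per-class sample size used in the next step.

```python
import os

IMAGE_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')


def find_incomplete_classes(folder, min_images=10000):
    """Return {class_name: image_count} for class subfolders holding fewer than min_images images."""
    incomplete = {}
    for entry in os.scandir(folder):
        if not entry.is_dir():
            continue
        count = sum(1 for f in os.scandir(entry.path)
                    if f.is_file() and f.name.lower().endswith(IMAGE_EXTS))
        if count < min_images:
            incomplete[entry.name] = count
    return incomplete
```

Deleting a flagged folder and rerunning csv2png.py would regenerate that class from scratch.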
😺4. Training
🐶4.1 split_datasets.py
First, split the dataset. The input is the png image dataset generated above.
import os
import random
import shutil
original_dataset_path = 'datasets256'  # Original dataset path
new_dataset_path = 'datasets'  # Output path for the split dataset
train_path = os.path.join(new_dataset_path, 'train')
val_path = os.path.join(new_dataset_path, 'val')
test_path = os.path.join(new_dataset_path, 'test')
for path in (train_path, val_path, test_path):
    os.makedirs(path, exist_ok=True)
classes = sorted(os.listdir(original_dataset_path))  # Get all categories in a fixed order
random.seed(42)
for class_name in classes:  # Traverse each category
    src_folder = os.path.join(original_dataset_path, class_name)  # Source folder path
    train_folder = os.path.join(train_path, class_name)
    val_folder = os.path.join(val_path, class_name)
    test_folder = os.path.join(test_path, class_name)
    # Skip categories whose train, val, and test folders already exist and are non-empty
    if all(os.path.isdir(p) and os.listdir(p) for p in (train_folder, val_folder, test_folder)):
        print(f"Category {class_name} already exists and is not empty, skip processing.")
        continue
    for path in (train_folder, val_folder, test_folder):  # Create the folders
        os.makedirs(path, exist_ok=True)
    # os.listdir order is arbitrary, so sort before shuffling to make the seed meaningful,
    # and shuffle before slicing so the 10000 images are a random sample
    files = sorted(os.listdir(src_folder))
    random.shuffle(files)
    files = files[:10000]
    total_files = len(files)
    train_split_index = int(total_files * 0.8)
    val_split_index = int(total_files * 0.9)
    splits = {
        train_folder: files[:train_split_index],
        val_folder: files[train_split_index:val_split_index],
        test_folder: files[val_split_index:],
    }
    for dst_folder, split_files in splits.items():
        for file in split_files:
            shutil.copy(os.path.join(src_folder, file), os.path.join(dst_folder, file))
print("Dataset partitioning completed!")
After the script finishes, the datasets directory contains three folders: train, val, and test.
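The split boundaries come from truncating int(total_files * 0.8) and int(total_files * 0.9). A small sketch (split_sizes is a hypothetical helper) reproduces that arithmetic: 10000 images per class give 8000/1000/1000, or 2,720,000/340,000/340,000 images over 340 classes.

```python
def split_sizes(total_files, train_end_frac=0.8, val_end_frac=0.9):
    """Train/val/test sizes using the same int() truncation as split_datasets.py."""
    train_end = int(total_files * train_end_frac)
    val_end = int(total_files * val_end_frac)
    return train_end, val_end - train_end, total_files - val_end


per_class = split_sizes(10000)
num_classes = 340
print(per_class, tuple(n * num_classes for n in per_class))
```

Because of the truncation, a class with fewer images does not split exactly 8:1:1; for example, 7 images become 5/1/1.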
🐶4.2 option.py
Define the parameters used by the later scripts.
import argparse
import os
def str2bool(value):
    # argparse's type=bool would turn any non-empty string, including 'False', into True
    if isinstance(value, bool):
        return value
    if value.lower() in ('true', 'yes', '1'):
        return True
    if value.lower() in ('false', 'no', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')
def get_args():
    parser = argparse.ArgumentParser(description='all argument')
    parser.add_argument('--num_classes', type=int, default=340, help='image num classes')
    parser.add_argument('--loadsize', type=int, default=64, help='image size')
    parser.add_argument('--epochs', type=int, default=100, help='all epochs')
    parser.add_argument('--batch_size', type=int, default=1024, help='batch size')
    parser.add_argument('--lr', type=float, default=0.001, help='init lr')
    parser.add_argument('--use_lr_scheduler', type=str2bool, default=True, help='use lr scheduler')
    parser.add_argument('--dataset_train', type=str, default='./datasets/train', help='train path')
    parser.add_argument('--dataset_val', type=str, default="./datasets/val", help='val path')
    parser.add_argument('--dataset_test', type=str, default="./datasets/test", help='test path')
    parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='ckpt path')
    parser.add_argument('--tensorboard_dir', type=str, default='./tensorboard_dir', help='log path')
    parser.add_argument('--resume', type=str2bool, default=False, help='continue training')
    parser.add_argument('--resume_ckpt', type=str, default='./checkpoints/model_best.pth', help='choose breakpoint ckpt')
    # torch.distributed.launch passes --local-rank (--local_rank before PyTorch 2.0); torchrun only sets LOCAL_RANK
    parser.add_argument('--local-rank', '--local_rank', type=int, default=int(os.environ.get('LOCAL_RANK', -1)), help='local rank')
    parser.add_argument('--use_mix_precision', type=str2bool, default=False, help='use mixed precision')
    parser.add_argument('--test_img_path', type=str, default='datasets/test/zigzag/zigzag-4508464694951936.png', help='choose test image')
    parser.add_argument('--test_dir_path', type=str, default='./datasets/test', help='choose test path')
    return parser.parse_args()
Because training uses DDP on a single machine with multiple GPUs together with AMP, the extra local-rank and use_mix_precision parameters are added.
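A flag declared with type=bool does not behave like a switch: argparse calls bool() on the raw string, and any non-empty string, including "False", is True. The snippet below demonstrates this and shows argparse.BooleanOptionalAction (Python 3.9+) as one alternative. Note that it changes the command-line syntax to --use_mix_precision / --no-use_mix_precision, so commands that pass an explicit True value would need updating.

```python
import argparse

# type=bool simply calls bool() on the command-line string.
naive = argparse.ArgumentParser()
naive.add_argument('--use_mix_precision', type=bool, default=False)
print(naive.parse_args(['--use_mix_precision', 'False']).use_mix_precision)  # True (!)

# Alternative (Python 3.9+): a real on/off switch with an automatic --no- form.
switch = argparse.ArgumentParser()
switch.add_argument('--use_mix_precision', action=argparse.BooleanOptionalAction, default=False)
print(switch.parse_args(['--use_mix_precision']).use_mix_precision)     # True
print(switch.parse_args(['--no-use_mix_precision']).use_mix_precision)  # False
print(switch.parse_args([]).use_mix_precision)                          # False
```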
🐶4.3 getdata.py
Next, define the data pipeline.
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from option import get_args
opt = get_args()
mean = [0.9367, 0.9404, 0.9405]
std = [0.1971, 0.1970, 0.1972]
def data_augmentation():
    # No random augmentation is applied: train and val share the same resize + normalize pipeline
    data_transform = {
        'train': transforms.Compose([
            transforms.Resize((opt.loadsize, opt.loadsize)),
            transforms.ToTensor(),  # HWC -> CHW, values scaled to [0, 1]
            transforms.Normalize(mean, std)
        ]),
        'val': transforms.Compose([
            transforms.Resize((opt.loadsize, opt.loadsize)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]),
    }
    return data_transform
def MyData():
    data_transform = data_augmentation()
    image_datasets = {
        'train': ImageFolder(opt.dataset_train, data_transform['train']),
        'val': ImageFolder(opt.dataset_val, data_transform['val']),
    }
    # Each DDP process gets its own shard; the training loop must call sampler.set_epoch(epoch) to reshuffle
    data_sampler = {
        'train': torch.utils.data.distributed.DistributedSampler(image_datasets['train']),
        'val': torch.utils.data.distributed.DistributedSampler(image_datasets['val'], shuffle=False),
    }
    # shuffle must stay False when a sampler is given; the sampler does the shuffling
    dataloaders = {
        'train': DataLoader(image_datasets['train'], batch_size=opt.batch_size, shuffle=False, num_workers=0, pin_memory=True, sampler=data_sampler['train']),
        'val': DataLoader(image_datasets['val'], batch_size=opt.batch_size, shuffle=False, num_workers=0, pin_memory=True, sampler=data_sampler['val'])
    }
    return dataloaders
class_names =[
'The Eiffel Tower', 'The Great Wall of China', 'The Mona Lisa', 'airplane', 'alarm clock', 'ambulance', 'angel',
'animal migration', 'ant', 'anvil', 'apple', 'arm', 'asparagus', 'axe', 'backpack', 'banana', 'bandage', 'barn',
'baseball', 'baseball bat', 'basket', 'basketball', 'bat', 'bathtub', 'beach', 'bear', 'beard', 'bed', 'bee',
'belt', 'bench', 'bicycle', 'binoculars', 'bird', 'birthday cake', 'blackberry', 'blueberry', 'book', 'boomerang',
'bottlecap', 'bowtie', 'bracelet', 'brain', 'bread', 'bridge', 'broccoli', 'broom', 'bucket', 'bulldozer', 'bus',
'bush', 'butterfly', 'cactus', 'cake', 'calculator', 'calendar', 'camel', 'camera', 'camouflage', 'campfire',
'candle', 'cannon', 'canoe', 'car', 'carrot', 'castle', 'cat', 'ceiling fan', 'cell phone', 'cello', 'chair',
'chandelier', 'church', 'circle', 'clarinet', 'clock', 'cloud', 'coffee cup', 'compass', 'computer', 'cookie',
'cooler', 'couch', 'cow', 'crab', 'crayon', 'crocodile', 'crown', 'cruise ship', 'cup', 'diamond', 'dishwasher',
'diving board', 'dog', 'dolphin', 'donut', 'door', 'dragon', 'dresser', 'drill', 'drums', 'duck', 'dumbbell',
'ear', 'elbow', 'elephant', 'envelope', 'eraser', 'eye', 'eyeglasses', 'face', 'fan', 'feather', 'fence', 'finger',
'fire hydrant', 'fireplace', 'firetruck', 'fish', 'flamingo', 'flashlight', 'flip flops', 'floor lamp', 'flower',
'flying saucer', 'foot', 'fork', 'frog', 'frying pan', 'garden', 'garden hose', 'giraffe', 'goatee', 'golf club',
'grapes', 'grass', 'guitar', 'hamburger', 'hammer', 'hand', 'harp', 'hat', 'headphones', 'hedgehog', 'helicopter',
'helmet', 'hexagon', 'hockey puck', 'hockey stick', 'horse', 'hospital', 'hot air balloon', 'hot dog', 'hot tub',
'hourglass', 'house', 'house plant', 'hurricane', 'ice cream', 'jacket', 'jail', 'kangaroo', 'key', 'keyboard',
'knee', 'ladder', 'lantern', 'laptop', 'leaf', 'leg', 'light bulb', 'lighthouse', 'lightning', 'line', 'lion',
'lipstick', 'lobster', 'lollipop', 'mailbox', 'map', 'marker', 'matches', 'megaphone', 'mermaid', 'microphone',
'microwave', 'monkey', 'moon', 'mosquito', 'motorbike', 'mountain', 'mouse', 'moustache', 'mouth', 'mug', 'mushroom',
'nail', 'necklace', 'nose', 'ocean', 'octagon', 'octopus', 'onion', 'oven', 'owl', 'paint can', 'paintbrush',
'palm tree', 'panda', 'pants', 'paper clip', 'parachute', 'parrot', 'passport', 'peanut', 'pear', 'peas', 'pencil',
'penguin', 'piano', 'pickup truck', 'picture frame', 'pig', 'pillow', 'pineapple', 'pizza', 'pliers', 'police car',
'pond', 'pool', 'popsicle', 'postcard', 'potato', 'power outlet', 'purse', 'rabbit', 'raccoon', 'radio', 'rain',
'rainbow', 'rake', 'remote control', 'rhinoceros', 'river', 'roller coaster', 'rollerskates', 'sailboat', 'sandwich',
'saw', 'saxophone', 'school bus', 'scissors', 'scorpion', 'screwdriver', 'sea turtle', 'see saw', 'shark', 'sheep',
'shoe', 'shorts', 'shovel', 'sink', 'skateboard', 'skull', 'skyscraper', 'sleeping bag', 'smiley face', 'snail',
'snake', 'snorkel', 'snowflake', 'snowman', 'soccer ball', 'sock', 'speedboat', 'spider', 'spoon', 'spreadsheet',
'square', 'squiggle', 'squirrel', 'stairs', 'star', 'steak', 'stereo', 'stethoscope', 'stitches', 'stop sign',
'stove', 'strawberry', 'streetlight', 'string bean', 'submarine', 'suitcase', 'sun', 'swan', 'sweater', 'swing set',
'sword', 't-shirt', 'table', 'teapot', 'teddy-bear', 'telephone', 'television', 'tennis racquet', 'tent', 'tiger',
'toaster', 'toe', 'toilet', 'tooth', 'toothbrush', 'toothpaste', 'tornado', 'tractor', 'traffic light', 'train',
'tree', 'triangle', 'trombone', 'truck', 'trumpet', 'umbrella', 'underwear', 'van', 'vase', 'violin', 'washing machine',
'watermelon', 'waterslide', 'whale', 'wheel', 'windmill', 'wine bottle', 'wine glass', 'wristwatch', 'yoga', 'zebra',
'zigzag'
]
if __name__ == '__main__':
    mean_std_transform = transforms.Compose([transforms.ToTensor()])
    dataset = ImageFolder(opt.dataset_val, transform=mean_std_transform)
    print(dataset.class_to_idx)  # Index for each category
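The source does not show how the mean and std constants were obtained. One common approach, sketched here under that assumption, is to accumulate per-channel sums and sums of squares over the training images after ToTensor() (values in [0, 1]). The pure-Python channel_mean_std helper is hypothetical and far too slow for 2.7 million images; in practice the same accumulation would run over DataLoader batches with tensors in float64.

```python
import math


def channel_mean_std(images):
    """Per-channel mean and population std.

    Each image is a list of rows, each row a list of (r, g, b) floats in [0, 1].
    """
    count = 0
    sums = [0.0, 0.0, 0.0]
    sq_sums = [0.0, 0.0, 0.0]
    for image in images:
        for row in image:
            for pixel in row:
                count += 1
                for c in range(3):
                    sums[c] += pixel[c]
                    sq_sums[c] += pixel[c] ** 2
    mean = [s / count for s in sums]
    # Var[x] = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding
    std = [math.sqrt(max(sq / count - m * m, 0.0)) for sq, m in zip(sq_sums, mean)]
    return mean, std


white_and_black = [[(1.0, 1.0, 1.0), (0.0, 0.0, 0.0)]]
all_white = [[(1.0, 1.0, 1.0), (1.0, 1.0, 1.0)]]
print(channel_mean_std([white_and_black, all_white]))
```

Mostly-white drawings explain why the constants in getdata.py have a mean near 0.94 and a std near 0.2.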
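🐶4.4 model.py
Define the model. This article uses the small variant of MobileNetV3, whose last classifier layer must be replaced so that it outputs the number of classes.
Other strong lightweight models, such as shufflenet or squeezenet, can also be trained on this dataset.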
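import torch
import torch.nn as nn
from torchvision.models import mobilenet_v3_small
from torchsummary import summary
from option import get_args
opt = get_args()
def CustomMobileNetV3():
    # ImageNet-pretrained MobileNetV3-Small
    model = mobilenet_v3_small(weights='MobileNet_V3_Small_Weights.IMAGENET1K_V1')
    # Replace the last classifier layer so that it outputs opt.num_classes logits
    model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, opt.num_classes)
    return model
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'  # option.py defines no device argument
    model = CustomMobileNetV3()
    print(model)
    summary(model.to(device), (3, opt.loadsize, opt.loadsize), opt.batch_size, device=device)  # summary() prints the table itself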
The model structure (torchsummary output for a batch of 1024 images of size 3x64x64) is as follows:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [1024, 16, 32, 32] 432
BatchNorm2d-2 [1024, 16, 32, 32] 32
Hardswish-3 [1024, 16, 32, 32] 0
Conv2d-4 [1024, 16, 16, 16] 144
BatchNorm2d-5 [1024, 16, 16, 16] 32
ReLU-6 [1024, 16, 16, 16] 0
AdaptiveAvgPool2d-7 [1024, 16, 1, 1] 0
Conv2d-8 [1024, 8, 1, 1] 136
ReLU-9 [1024, 8, 1, 1] 0
Conv2d-10 [1024, 16, 1, 1] 144
Hardsigmoid-11 [1024, 16, 1, 1] 0
SqueezeExcitation-12 [1024, 16, 16, 16] 0
Conv2d-13 [1024, 16, 16, 16] 256
BatchNorm2d-14 [1024, 16, 16, 16] 32
InvertedResidual-15 [1024, 16, 16, 16] 0
Conv2d-16 [1024, 72, 16, 16] 1,152
BatchNorm2d-17 [1024, 72, 16, 16] 144
ReLU-18 [1024, 72, 16, 16] 0
Conv2d-19 [1024, 72, 8, 8] 648
BatchNorm2d-20 [1024, 72, 8, 8] 144
ReLU-21 [1024, 72, 8, 8] 0
Conv2d-22 [1024, 24, 8, 8] 1,728
BatchNorm2d-23 [1024, 24, 8, 8] 48
InvertedResidual-24 [1024, 24, 8, 8] 0
Conv2d-25 [1024, 88, 8, 8] 2,112
BatchNorm2d-26 [1024, 88, 8, 8] 176
ReLU-27 [1024, 88, 8, 8] 0
Conv2d-28 [1024, 88, 8, 8] 792
BatchNorm2d-29 [1024, 88, 8, 8] 176
ReLU-30 [1024, 88, 8, 8] 0
Conv2d-31 [1024, 24, 8, 8] 2,112
BatchNorm2d-32 [1024, 24, 8, 8] 48
InvertedResidual-33 [1024, 24, 8, 8] 0
Conv2d-34 [1024, 96, 8, 8] 2,304
BatchNorm2d-35 [1024, 96, 8, 8] 192
Hardswish-36 [1024, 96, 8, 8] 0
Conv2d-37 [1024, 96, 4, 4] 2,400
BatchNorm2d-38 [1024, 96, 4, 4] 192
Hardswish-39 [1024, 96, 4, 4] 0
AdaptiveAvgPool2d-40 [1024, 96, 1, 1] 0
Conv2d-41 [1024, 24, 1, 1] 2,328
ReLU-42 [1024, 24, 1, 1] 0
Conv2d-43 [1024, 96, 1, 1] 2,400
Hardsigmoid-44 [1024, 96, 1, 1] 0
SqueezeExcitation-45 [1024, 96, 4, 4] 0
Conv2d-46 [1024, 40, 4, 4] 3,840
BatchNorm2d-47 [1024, 40, 4, 4] 80
InvertedResidual-48 [1024, 40, 4, 4] 0
Conv2d-49 [1024, 240, 4, 4] 9,600
BatchNorm2d-50 [1024, 240, 4, 4] 480
Hardswish-51 [1024, 240, 4, 4] 0
Conv2d-52 [1024, 240, 4, 4] 6,000
BatchNorm2d-53 [1024, 240, 4, 4] 480
Hardswish-54 [1024, 240, 4, 4] 0
AdaptiveAvgPool2d-55 [1024, 240, 1, 1] 0
Conv2d-56 [1024, 64, 1, 1] 15,424
ReLU-57 [1024, 64, 1, 1] 0
Conv2d-58 [1024, 240, 1, 1] 15,600
Hardsigmoid-59 [1024, 240, 1, 1] 0
SqueezeExcitation-60 [1024, 240, 4, 4] 0
Conv2d-61 [1024, 40, 4, 4] 9,600
BatchNorm2d-62 [1024, 40, 4, 4] 80
InvertedResidual-63 [1024, 40, 4, 4] 0
Conv2d-64 [1024, 240, 4, 4] 9,600
BatchNorm2d-65 [1024, 240, 4, 4] 480
Hardswish-66 [1024, 240, 4, 4] 0
Conv2d-67 [1024, 240, 4, 4] 6,000
BatchNorm2d-68 [1024, 240, 4, 4] 480
Hardswish-69 [1024, 240, 4, 4] 0
AdaptiveAvgPool2d-70 [1024, 240, 1, 1] 0
Conv2d-71 [1024, 64, 1, 1] 15,424
ReLU-72 [1024, 64, 1, 1] 0
Conv2d-73 [1024, 240, 1, 1] 15,600
Hardsigmoid-74 [1024, 240, 1, 1] 0
SqueezeExcitation-75 [1024, 240, 4, 4] 0
Conv2d-76 [1024, 40, 4, 4] 9,600
BatchNorm2d-77 [1024, 40, 4, 4] 80
InvertedResidual-78 [1024, 40, 4, 4] 0
Conv2d-79 [1024, 120, 4, 4] 4,800
BatchNorm2d-80 [1024, 120, 4, 4] 240
Hardswish-81 [1024, 120, 4, 4] 0
Conv2d-82 [1024, 120, 4, 4] 3,000
BatchNorm2d-83 [1024, 120, 4, 4] 240
Hardswish-84 [1024, 120, 4, 4] 0
AdaptiveAvgPool2d-85 [1024, 120, 1, 1] 0
Conv2d-86 [1024, 32, 1, 1] 3,872
ReLU-87 [1024, 32, 1, 1] 0
Conv2d-88 [1024, 120, 1, 1] 3,960
Hardsigmoid-89 [1024, 120, 1, 1] 0
SqueezeExcitation-90 [1024, 120, 4, 4] 0
Conv2d-91 [1024, 48, 4, 4] 5,760
BatchNorm2d-92 [1024, 48, 4, 4] 96
InvertedResidual-93 [1024, 48, 4, 4] 0
Conv2d-94 [1024, 144, 4, 4] 6,912
BatchNorm2d-95 [1024, 144, 4, 4] 288
Hardswish-96 [1024, 144, 4, 4] 0
Conv2d-97 [1024, 144, 4, 4] 3,600
BatchNorm2d-98 [1024, 144, 4, 4] 288
Hardswish-99 [1024, 144, 4, 4] 0
AdaptiveAvgPool2d-100 [1024, 144, 1, 1] 0
Conv2d-101 [1024, 40, 1, 1] 5,800
ReLU-102 [1024, 40, 1, 1] 0
Conv2d-103 [1024, 144, 1, 1] 5,904
Hardsigmoid-104 [1024, 144, 1, 1] 0
SqueezeExcitation-105 [1024, 144, 4, 4] 0
Conv2d-106 [1024, 48, 4, 4] 6,912
BatchNorm2d-107 [1024, 48, 4, 4] 96
InvertedResidual-108 [1024, 48, 4, 4] 0
Conv2d-109 [1024, 288, 4, 4] 13,824
BatchNorm2d-110 [1024, 288, 4, 4] 576
Hardswish-111 [1024, 288, 4, 4] 0
Conv2d-112 [1024, 288, 2, 2] 7,200
BatchNorm2d-113 [1024, 288, 2, 2] 576
Hardswish-114 [1024, 288, 2, 2] 0
AdaptiveAvgPool2d-115 [1024, 288, 1, 1] 0
Conv2d-116 [1024, 72, 1, 1] 20,808
ReLU-117 [1024, 72, 1, 1] 0
Conv2d-118 [1024, 288, 1, 1] 21,024
Hardsigmoid-119 [1024, 288, 1, 1] 0
SqueezeExcitation-120 [1024, 288, 2, 2] 0
Conv2d-121 [1024, 96, 2, 2] 27,648
BatchNorm2d-122 [1024, 96, 2, 2] 192
InvertedResidual-123 [1024, 96, 2, 2] 0
Conv2d-124 [1024, 576, 2, 2] 55,296
BatchNorm2d-125 [1024, 576, 2, 2] 1,152
Hardswish-126 [1024, 576, 2, 2] 0
Conv2d-127 [1024, 576, 2, 2] 14,400
BatchNorm2d-128 [1024, 576, 2, 2] 1,152
Hardswish-129 [1024, 576, 2, 2] 0
AdaptiveAvgPool2d-130 [1024, 576, 1, 1] 0
Conv2d-131 [1024, 144, 1, 1] 83,088
ReLU-132 [1024, 144, 1, 1] 0
Conv2d-133 [1024, 576, 1, 1] 83,520
Hardsigmoid-134 [1024, 576, 1, 1] 0
SqueezeExcitation-135 [1024, 576, 2, 2] 0
Conv2d-136 [1024, 96, 2, 2] 55,296
BatchNorm2d-137 [1024, 96, 2, 2] 192
InvertedResidual-138 [1024, 96, 2, 2] 0
Conv2d-139 [1024, 576, 2, 2] 55,296
BatchNorm2d-140 [1024, 576, 2, 2] 1,152
Hardswish-141 [1024, 576, 2, 2] 0
Conv2d-142 [1024, 576, 2, 2] 14,400
BatchNorm2d-143 [1024, 576, 2, 2] 1,152
Hardswish-144 [1024, 576, 2, 2] 0
AdaptiveAvgPool2d-145 [1024, 576, 1, 1] 0
Conv2d-146 [1024, 144, 1, 1] 83,088
ReLU-147 [1024, 144, 1, 1] 0
Conv2d-148 [1024, 576, 1, 1] 83,520
Hardsigmoid-149 [1024, 576, 1, 1] 0
SqueezeExcitation-150 [1024, 576, 2, 2] 0
Conv2d-151 [1024, 96, 2, 2] 55,296
BatchNorm2d-152 [1024, 96, 2, 2] 192
InvertedResidual-153 [1024, 96, 2, 2] 0
Conv2d-154 [1024, 576, 2, 2] 55,296
BatchNorm2d-155 [1024, 576, 2, 2] 1,152
Hardswish-156 [1024, 576, 2, 2] 0
AdaptiveAvgPool2d-157 [1024, 576, 1, 1] 0
Linear-158 [1024, 1024] 590,848
Hardswish-159 [1024, 1024] 0
Dropout-160 [1024, 1024] 0
Linear-161 [1024, 340] 348,500
================================================================
Total params: 1,866,356
Trainable params: 1,866,356
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 48.00
Forward/backward pass size (MB): 2979.22
Params size (MB): 7.12
Estimated Total Size (MB): 3034.34
----------------------------------------------------------------
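The parameter counts in the table can be checked by hand. A Linear layer has in_features x out_features weights plus out_features biases, and a Conv2d has out_channels x (in_channels / groups) x k x k weights, plus out_channels biases when bias is enabled (as in the squeeze-excitation 1x1 convolutions). The helpers below are hypothetical. The last line shows that swapping the 1000-class ImageNet head (1,025,000 parameters) for the 340-class head (348,500) turns the 1,866,356 total above into 2,542,856, the parameter count torchvision lists for mobilenet_v3_small.

```python
def linear_params(in_features, out_features, bias=True):
    return in_features * out_features + (out_features if bias else 0)


def conv2d_params(in_channels, out_channels, kernel_size, groups=1, bias=False):
    weights = out_channels * (in_channels // groups) * kernel_size * kernel_size
    return weights + (out_channels if bias else 0)


print(conv2d_params(3, 16, 3))              # Conv2d-1
print(conv2d_params(16, 16, 3, groups=16))  # Conv2d-4, depthwise
print(conv2d_params(16, 8, 1, bias=True))   # Conv2d-8, SE reduction with bias
print(linear_params(576, 1024))             # Linear-158
print(linear_params(1024, 340))             # Linear-161, the new 340-class head
print(1_866_356 - linear_params(1024, 340) + linear_params(1024, 1000))  # with the ImageNet head
```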
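🐶4.5 train-DDP.py
Note that train-DDP.py combines several training strategies:
- DDP distributed training (single machine, two GPUs);
- AMP mixed-precision training;
- Learning-rate decay;
- Early stopping;
- Resuming training from a checkpoint.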
# python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="192.168.8.89" --master_port=12345 train-DDP.py --use_mix_precision True
# Watch the training log: tensorboard --logdir=tensorboard_dir
import os
import time
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
from torch.cuda.amp import GradScaler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from getdata import MyData
from model import CustomMobileNetV3
from option import get_args
opt = get_args()
dist.init_process_group(backend='nccl', init_method='env://')
os.makedirs(opt.checkpoints, exist_ok=True)
def train(gpu):
    rank = dist.get_rank()
    torch.cuda.set_device(gpu)  # bind this process to its own GPU
    model = CustomMobileNetV3()
    model.cuda(gpu)
    criterion = nn.CrossEntropyLoss().to(gpu)
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    scaler = GradScaler(enabled=opt.use_mix_precision)
    dataloaders = MyData()
    train_loader = dataloaders['train']
    test_loader = dataloaders['val']
    scheduler = None
    if opt.use_lr_scheduler:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
    start_time = time.time()
    best_val_acc = 0.0
    no_improve_epochs = 0
    early_stopping_patience = 6  # Early Stopping Patience
    # Resume from a checkpoint
    if opt.resume:
        checkpoint = torch.load(opt.resume_ckpt, map_location=f'cuda:{gpu}')
        print('Loading checkpoint from:', opt.resume_ckpt)
        # The checkpoint holds model.module.state_dict() (no 'module.' prefix), so load it into the wrapped module
        model.module.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']  # Set the starting epoch
        if scheduler is not None and checkpoint.get('scheduler') is not None:
            scheduler.load_state_dict(checkpoint['scheduler'])
    else:
        start_epoch = 0
    for epoch in range(start_epoch + 1, opt.epochs):
        model.train()  # valid() switches the model to eval mode, so switch back every epoch
        train_loader.sampler.set_epoch(epoch)  # give DistributedSampler a new shuffle each epoch
        tqdm_trainloader = tqdm(train_loader, desc=f'Epoch {epoch}', disable=rank != 0)
        running_loss, running_correct_top1, running_correct_top3, running_correct_top5 = 0.0, 0.0, 0.0, 0.0
        total_samples = 0
        for images, target in tqdm_trainloader:
            images = images.to(gpu, non_blocking=True)
            target = target.to(gpu, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=opt.use_mix_precision):
                output = model(images)
                loss = criterion(output, target)
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.item() * images.size(0)
            _, predicted_top5 = torch.topk(output.detach(), 5, dim=1)  # sorted, best guess first
            hits = predicted_top5 == target.unsqueeze(1)  # [batch, 5] boolean
            running_correct_top1 += hits[:, :1].sum().item()
            running_correct_top3 += hits[:, :3].sum().item()
            running_correct_top5 += hits.sum().item()
            total_samples += target.size(0)
        state = {'epoch': epoch,
                 'model': model.module.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'scheduler': scheduler.state_dict() if scheduler is not None else None}
        if rank == 0:
            current_lr = scheduler.get_last_lr()[0] if scheduler is not None else opt.lr
            # Training metrics cover rank 0's shard only, so divide by the samples this rank saw
            print(f'[Epoch {epoch}] '
                  f'[Train Loss: {running_loss / total_samples:.6f}] '
                  f'[Train Top-1 Acc: {running_correct_top1 / total_samples:.6f}] '
                  f'[Train Top-3 Acc: {running_correct_top3 / total_samples:.6f}] '
                  f'[Train Top-5 Acc: {running_correct_top5 / total_samples:.6f}] '
                  f'[Learning Rate: {current_lr:.6f}] '
                  f'[Time: {time.time() - start_time:.6f} seconds]')
            writer.add_scalar('Train/Loss', running_loss / total_samples, epoch)
            writer.add_scalar('Train/Top-1 Accuracy', running_correct_top1 / total_samples, epoch)
            writer.add_scalar('Train/Top-3 Accuracy', running_correct_top3 / total_samples, epoch)
            writer.add_scalar('Train/Top-5 Accuracy', running_correct_top5 / total_samples, epoch)
            writer.add_scalar('Train/Learning Rate', current_lr, epoch)
            torch.save(state, f'{opt.checkpoints}model_epoch_{epoch}.pth')
        tqdm_trainloader.close()
        if scheduler is not None:  # Learning-rate Scheduler
            scheduler.step()
        acc_top1 = valid(test_loader, model, epoch, gpu, rank)  # identical on every rank
        if acc_top1 > best_val_acc:
            best_val_acc = acc_top1
            no_improve_epochs = 0
            if rank == 0:
                torch.save(state, f'{opt.checkpoints}/model_best.pth')
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= early_stopping_patience:
                if rank == 0:
                    print(f'Early stopping triggered after {early_stopping_patience} epochs without improvement.')
                break  # all ranks see the same accuracy, so they all leave the loop together
    dist.destroy_process_group()
def valid(val_loader, model, epoch, gpu, rank):
    model.eval()
    correct_top1, correct_top3, correct_top5, total = (torch.tensor(0.).to(gpu) for _ in range(4))
    with torch.no_grad():
        for images, target in tqdm(val_loader, desc=f'Epoch {epoch}', disable=rank != 0):
            images = images.to(gpu, non_blocking=True)
            target = target.to(gpu, non_blocking=True)
            output = model(images)
            total += target.size(0)
            _, predicted_top5 = torch.topk(output, 5, dim=1)
            hits = predicted_top5 == target.unsqueeze(1)
            correct_top1 += hits[:, :1].sum()
            correct_top3 += hits[:, :3].sum()
            correct_top5 += hits.sum()
    # all_reduce (rather than reduce to rank 0) leaves the summed counts on every rank, so the
    # early-stopping decision in train() is the same everywhere. DistributedSampler pads the
    # dataset to a multiple of the world size, so a few samples may be counted twice.
    for t in (total, correct_top1, correct_top3, correct_top5):
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
    acc_top1 = (correct_top1 / total).item()
    if rank == 0:
        acc_top3 = (correct_top3 / total).item()
        acc_top5 = (correct_top5 / total).item()
        print(f'[Epoch {epoch}] '
              f'[Val Top-1 Acc: {acc_top1:.6f}] '
              f'[Val Top-3 Acc: {acc_top3:.6f}] '
              f'[Val Top-5 Acc: {acc_top5:.6f}]')
        writer.add_scalar('Validation/Top-1 Accuracy', acc_top1, epoch)
        writer.add_scalar('Validation/Top-3 Accuracy', acc_top3, epoch)
        writer.add_scalar('Validation/Top-5 Accuracy', acc_top5, epoch)
    return acc_top1  # Top-1 accuracy
def main():
    train(opt.local_rank)
if __name__ == '__main__':
    writer = SummaryWriter(log_dir=opt.tensorboard_dir)
    main()
    writer.close()
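Start multi-GPU distributed training from a terminal with the following command:
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="192.168.8.89" --master_port=12345 train-DDP.py --use_mix_precision True
The parameters mean:
- nproc_per_node: number of processes per machine, i.e. the number of GPUs to use
- nnodes: number of machines
- node_rank: index of this machine
- master_addr: IP address of the master machine
- master_port: port of the master machine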
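With this command, nnodes=1 and nproc_per_node=2 give a world size of 2. Each process gets a global rank of node_rank x nproc_per_node + local_rank, and DistributedSampler hands every rank its own slice of the dataset, so --batch_size 1024 is per GPU (2048 images per optimizer step across both GPUs). The sketch below mimics the sampler's index split for shuffle=False: pad by wrapping around to a multiple of the world size, then take every world_size-th index. The helper names are hypothetical; with shuffle=True the real sampler first permutes the indices with a seed that depends on set_epoch().

```python
import math


def world_size(nnodes, nproc_per_node):
    """Total number of training processes."""
    return nnodes * nproc_per_node


def global_rank(node_rank, nproc_per_node, local_rank):
    """Rank of one process across all machines."""
    return node_rank * nproc_per_node + local_rank


def samples_per_rank(dataset_len, num_replicas):
    """Per-rank sample count when the sampler pads instead of dropping (drop_last=False)."""
    return math.ceil(dataset_len / num_replicas)


def shard_indices(dataset_len, rank, num_replicas):
    """Indices one rank receives, mimicking DistributedSampler(shuffle=False, drop_last=False)."""
    total_size = samples_per_rank(dataset_len, num_replicas) * num_replicas
    indices = list(range(dataset_len))
    indices += indices[: total_size - dataset_len]  # wrap around to pad (assumes padding < dataset_len)
    return indices[rank:total_size:num_replicas]


ws = world_size(nnodes=1, nproc_per_node=2)
train_per_rank = samples_per_rank(2_720_000, ws)
print(ws, [global_rank(0, 2, r) for r in range(2)])
print(train_per_rank, math.ceil(train_per_rank / 1024))  # samples and batches per GPU per epoch
print(shard_indices(5, 0, 2), shard_indices(5, 1, 2))
```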
Launching the training with nohup runs into a bug: the launcher still receives SIGHUP and shuts the workers down:
W0914 18:33:15.081479 140031432897728 torch/distributed/elastic/agent/server/api.py:741] Received Signals.SIGHUP death signal, shutting down workers
W0914 18:33:15.085310 140031432897728 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 1685186 closing signal SIGHUP
W0914 18:33:15.085644 140031432897728 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 1685192 closing signal SIGHUP
For the cause, see the discussion on the official PyTorch forum: DDP Error: torch.distributed.elastic.agent.server.api:Received 1 death signal, shutting down workers
The problem can be avoided by running the training inside tmux.
- Install tmux: sudo apt-get install tmux
- Create a session: tmux new -s train-DDP (choose any session name)
- Activate the virtual environment: conda activate pytorch (use your own environment name)
- Start the training job: python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="192.168.8.89" --master_port=12345 train-DDP.py --use_mix_precision True
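Common tmux commands:
- List all current tmux sessions: tmux ls
- Create a session: tmux new -s <session-name>
- Detach from the current session and leave it running: press Ctrl+b, then d
- Re-attach to a session: tmux attach -t <session-name>
- Kill a session: tmux kill-session -t <session-name>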
The training log from this article is shown in the figure below.
Early stopping was triggered at epoch 11.
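How the learning-rate schedule and early stopping in train-DDP.py interact can be simulated without a GPU. A sketch with hypothetical helpers: step_lr reproduces StepLR(step_size=10, gamma=0.8), where the rate printed at epoch e (when training starts from scratch) has seen e - 1 scheduler steps; stop_epoch replays the patience-6 rule on a list of validation accuracies. Because the loop starts at epoch 1, a stop at epoch 11 means the last improvement was at epoch 5. The accuracy values in the example are invented, not results from this run.

```python
def step_lr(base_lr, steps_taken, step_size=10, gamma=0.8):
    """Learning rate after `steps_taken` calls to StepLR.step()."""
    return base_lr * gamma ** (steps_taken // step_size)


class EarlyStopping:
    """Stop after `patience` consecutive epochs without a new best accuracy."""

    def __init__(self, patience=6):
        self.patience = patience
        self.best = 0.0
        self.bad_epochs = 0

    def step(self, acc):
        """Record one validation accuracy; return True when training should stop."""
        if acc > self.best:
            self.best = acc
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def stop_epoch(val_accs, patience=6, first_epoch=1):
    """Epoch at which training stops, or None if it runs through all given epochs."""
    stopper = EarlyStopping(patience)
    for epoch, acc in enumerate(val_accs, start=first_epoch):
        if stopper.step(acc):
            return epoch
    return None


# Invented accuracies: the best value appears at epoch 5, then six epochs without improvement.
accs = [0.50, 0.60, 0.65, 0.66, 0.67] + [0.66] * 6
print(stop_epoch(accs), step_lr(0.001, 11 - 1))
```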
🐶4.6 model_transfer.py
This script converts the pth checkpoint into the mobile ptl format and the onnx format, which makes on-device deployment easier.
from torch.utils.mobile_optimizer import optimize_for_mobile
import torch
import onnx
from onnxsim import simplify
from model import CustomMobileNetV3
from option import get_args
opt = get_args()
# Rebuild the network and load the best weights saved during training
model = CustomMobileNetV3()
checkpoint = torch.load(f'{opt.checkpoints}model_best.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'])
model.eval()
print("Model loaded successfully.")
# One dummy input, shaped like the training images (N, C, H, W), shared by tracing and ONNX export
# (torch.autograd.Variable is deprecated; a plain tensor is enough)
dummy_input = torch.randn(1, 3, opt.loadsize, opt.loadsize)
"""Save .pth format model (the whole pickled module, not just the state_dict)"""
torch.save(model, f'{opt.checkpoints}model.pth')
"""Save .ptl format model (TorchScript for the PyTorch Lite interpreter)"""
traced_script_module = torch.jit.trace(model, dummy_input)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized._save_for_lite_interpreter(f'{opt.checkpoints}model.ptl')
"""Save .onnx format model"""
onnx_path = f'{opt.checkpoints}model.onnx'
torch.onnx.export(model, dummy_input, onnx_path, input_names=['input'], output_names=['output'], verbose=True)
# Run ONNX shape inference so intermediate tensors carry shape information
onnx.save(onnx.shape_inference.infer_shapes(onnx.load(onnx_path)), onnx_path)
"""Simplify the ONNX model with onnx-simplifier"""
model_simplified, check = simplify(onnx.load(onnx_path))
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simplified, f'{opt.checkpoints}model_simplified.onnx')
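After a conversion, it is good practice to check that the exported model produces the same outputs as the original PyTorch model on the same input. The NumPy sketch below performs that comparison. The function name compare_outputs and the tolerances are illustrative assumptions. In practice, ref would be the PyTorch logits (model output converted with .detach().numpy()), and test would be the first result of an onnxruntime.InferenceSession run on the same random input tensor.

```python
import numpy as np


def compare_outputs(ref, test, rtol=1e-3, atol=1e-5):
    """Compare two logit arrays of shape (batch, num_classes) (hypothetical helper).

    Returns the maximum absolute difference, whether all elements agree within
    the given tolerances, and whether the top-1 class agrees for every sample.
    """
    ref = np.asarray(ref, dtype=np.float32)
    test = np.asarray(test, dtype=np.float32)
    if ref.shape != test.shape:
        raise ValueError(f'shape mismatch: {ref.shape} vs {test.shape}')
    max_abs_diff = float(np.max(np.abs(ref - test)))
    close = bool(np.allclose(test, ref, rtol=rtol, atol=atol))
    same_top1 = bool(np.array_equal(ref.argmax(axis=1), test.argmax(axis=1)))
    return max_abs_diff, close, same_top1


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    logits = rng.standard_normal((2, 10)).astype(np.float32)
    # Simulate the tiny floating-point differences typically seen between runtimes
    noisy = logits + 1e-6
    print(compare_outputs(logits, noisy))
```

Small element-wise differences, around 1e-5 or below, are normal between runtimes. A mismatch in the top-1 class, however, usually points to an export or preprocessing problem.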
🐶4.7 evaluate.py
The script defines three functions:
- evaluate_image_single: predicts a single image
- evaluate_image_dir: predicts every image in a folder
- evaluate_onnx_model: predicts an image with the ONNX model
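It also produces several visualizations and evaluation metrics: Top-1/3/5 accuracy, precision, recall, F1 score, a per-class classification report, and a confusion matrix.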
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torch.nn.functional as F
import torch.utils.data
import onnxruntime
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix, roc_curve, auc
from tqdm import tqdm
from getdata import mean, std, class_names
from option import get_args
opt = get_args()
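# The author evaluated on the second GPU; fall back to the first GPU or the CPU when it is not available
device = 'cuda:1' if torch.cuda.device_count() > 1 else ('cuda' if torch.cuda.is_available() else 'cpu')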
"""Predicting a single image"""
def evaluate_image_single(img_path, transform_test, model, class_names, top_k):
    model.eval()
    image = Image.open(img_path).convert('RGB')
    img = transform_test(image).unsqueeze(0).to(device)  # add a batch dimension: (1, C, H, W)
    with torch.no_grad():
        out = model(img)
    pred_softmax = F.softmax(out, dim=1)
    # torch.topk returns the k largest probabilities already sorted in descending order
    top_n, top_n_indices = torch.topk(pred_softmax, top_k)
    confs = top_n[0].cpu().numpy().tolist()
    class_names_top = [class_names[i] for i in top_n_indices[0].tolist()]
    for i in range(top_k):
        print(f'Pre: {class_names_top[i]} Conf: {confs[i]:.3f}')
    confs_max = confs[0]
    plt.figure(figsize=(10, 5))
    # Left: the input image, titled with the top-1 prediction
    plt.subplot(1, 2, 1)
    plt.axis('off')
    plt.title(f'Pre: {class_names_top[0]} Conf: {confs_max:.3f}')
    plt.imshow(image)
    # Right: bar chart of the top-k confidences
    plt.subplot(1, 2, 2)
    bars = plt.bar(class_names_top, confs, color='lightcoral')
    plt.xlabel('Class Names')
    plt.ylabel('Confidence')
    plt.title(f'Top {top_k} Predictions (Descending Order)')
    plt.xticks(rotation=45)
    plt.ylim(0, 1)
    plt.tight_layout()
    for bar, conf in zip(bars, confs):
        yval = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2, yval + 0.01, f'{conf:.3f}', ha='center', va='bottom')
    plt.savefig('predict_image_with_bars.jpg')
"""Predicting folder images"""
def evaluate_image_dir(model, dataloader, class_names):
    model.eval()
    all_preds = []
    all_labels = []
    correct_top1, correct_top3, correct_top5, total = 0, 0, 0, 0
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            total += labels.size(0)
            # hits[n, j] is True if the j-th ranked class of sample n is its true label
            _, predicted_top5 = torch.topk(outputs, 5, dim=1)
            hits = predicted_top5 == labels.unsqueeze(1)
            correct_top1 += hits[:, :1].sum().item()
            correct_top3 += hits[:, :3].sum().item()
            correct_top5 += hits.sum().item()
            # Collect whole batches on the CPU instead of extending a list with per-element GPU tensors
            all_preds.append(outputs.argmax(1).cpu())
            all_labels.append(labels.cpu())
    all_preds = torch.cat(all_preds).numpy()
    all_labels = torch.cat(all_labels).numpy()
    top1 = correct_top1 / total
    top3 = correct_top3 / total
    top5 = correct_top5 / total
    print(f"Top-1 Accuracy: {top1:.4f}")
    print(f"Top-3 Accuracy: {top3:.4f}")
    print(f"Top-5 Accuracy: {top5:.4f}")
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='macro')
    recall = recall_score(all_labels, all_preds, average='macro')
    f1 = f1_score(all_labels, all_preds, average='macro')
    cm = confusion_matrix(all_labels, all_preds)
    report = classification_report(all_labels, all_preds, target_names=class_names)
    print(f'Accuracy: {accuracy:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'F1 Score: {f1:.4f}')
    print(report)
    # With hundreds of classes, the annotated heatmap needs a very large canvas to stay legible
    plt.figure(figsize=(100, 100))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names, annot_kws={"size": 8})
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix')
    plt.savefig('confusion_matrix.jpg')
"""Using .onnx model to predict images"""
def evaluate_onnx_model(img_path, data_transform, onnx_model_path, class_names, top_k=5):
    # onnxruntime >= 1.9 requires an explicit providers list on builds that ship several
    # execution providers (e.g. onnxruntime-gpu); the CPU is enough for single-image inference
    ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])
    img_pil = Image.open(img_path).convert('RGB')
    input_img = data_transform(img_pil)
    input_tensor = input_img.unsqueeze(0).numpy()  # (1, C, H, W), float32
    ort_inputs = {'input': input_tensor}  # key must match the input_names used in torch.onnx.export
    out = ort_session.run(['output'], ort_inputs)[0]
    def softmax(x):
        # Subtracting the row maximum avoids overflow in np.exp without changing the result
        e = np.exp(x - np.max(x, axis=1, keepdims=True))
        return e / np.sum(e, axis=1, keepdims=True)
    prob_dist = softmax(out)
    result_dict = {label: float(prob_dist[0][i]) for i, label in enumerate(class_names)}
    result_dict = dict(sorted(result_dict.items(), key=lambda item: item[1], reverse=True))
    class_names_top = list(result_dict.keys())[:top_k]
    confs_top = list(result_dict.values())[:top_k]
    for name, conf in zip(class_names_top, confs_top):
        print(f'Pre: {name} Conf: {conf:.3f}')
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.axis('off')
    plt.title(f'Pre: {class_names_top[0]} Conf: {confs_top[0]:.3f}')
    plt.imshow(img_pil)
    plt.subplot(1, 2, 2)
    bars = plt.bar(class_names_top, confs_top, color='lightcoral')
    plt.xlabel('Class Names')
    plt.ylabel('Confidence')
    plt.title(f'Top {top_k} Predictions (Descending Order)')
    plt.xticks(rotation=45)
    plt.ylim(0, 1)
    plt.tight_layout()
    for bar, conf in zip(bars, confs_top):
        yval = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2, yval + 0.01, f'{conf:.3f}', ha='center', va='bottom')
    plt.savefig('predict_image_with_bars.jpg')
if __name__ == '__main__':
    data_transform = transforms.Compose([transforms.Resize((opt.loadsize, opt.loadsize)), transforms.ToTensor(), transforms.Normalize(mean, std)])
    image_datasets = ImageFolder(opt.dataset_test, data_transform)
    dataloaders = DataLoader(image_datasets, batch_size=512, shuffle=False)  # shuffling is unnecessary for evaluation
    ptl_model_path = opt.checkpoints + 'model.ptl'
    pth_model_path = opt.checkpoints + 'model.pth'
    onnx_model_path = opt.checkpoints + 'model.onnx'
    ptl_model = torch.jit.load(ptl_model_path).to(device)
    # model.pth stores the whole pickled module, so weights_only must be disabled
    # (PyTorch >= 2.6 defaults to weights_only=True); only load checkpoints you trust
    pth_model = torch.load(pth_model_path, map_location=device, weights_only=False)
    evaluate_image_single(opt.test_img_path, data_transform, pth_model, class_names, top_k=5)  # Predicting a single image
    # evaluate_image_dir(pth_model, dataloaders, class_names)  # Predicting folder images
    # evaluate_onnx_model(opt.test_img_path, data_transform, onnx_model_path, class_names, top_k=5)  # Predicting a single image with the ONNX model
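The Top-1/3/5 accuracies printed by evaluate_image_dir count a sample as correct when its true label appears among the k highest-scoring classes. The small NumPy sketch below applies the same rule without PyTorch. The function name topk_accuracy and the toy scores are illustrative only.

```python
import numpy as np


def topk_accuracy(logits, labels, k):
    """Fraction of samples whose true label is among the k largest scores."""
    logits = np.asarray(logits)
    labels = np.asarray(labels)
    # argsort is ascending, so the last k columns hold the indices of the k largest scores
    topk = np.argsort(logits, axis=1)[:, -k:]
    hits = (topk == labels[:, None]).any(axis=1)
    return float(hits.mean())


if __name__ == '__main__':
    logits = np.array([[0.1, 0.7, 0.2],   # true class 1 is ranked 1st
                       [0.5, 0.3, 0.2],   # true class 2 is ranked 3rd
                       [0.2, 0.3, 0.5]])  # true class 1 is ranked 2nd
    labels = np.array([1, 2, 1])
    for k in (1, 2, 3):
        print(f'Top-{k} accuracy: {topk_accuracy(logits, labels, k):.3f}')
```

By construction, Top-k accuracy can never decrease as k grows. This is why Top-1 ≤ Top-3 ≤ Top-5 in the results that follow. A large gap between Top-1 and Top-5 suggests that many errors come from confusing visually similar doodles.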
Running evaluate_image_single on datasets/test/zigzag/zigzag-4508464694951936.png gives the following result:
Running evaluate_image_dir on the images under datasets/test gives the following results:
Top-1 Accuracy: 0.6833
Top-3 Accuracy: 0.8521
Top-5 Accuracy: 0.8933
Accuracy: 0.6833
Precision: 0.6875
Recall: 0.6833
F1 Score: 0.6817
precision recall f1-score support
The Eiffel Tower 0.83 0.88 0.85 1000
The Great Wall of China 0.47 0.36 0.41 1000
The Mona Lisa 0.68 0.86 0.76 1000
airplane 0.83 0.74 0.78 1000
alarm clock 0.76 0.76 0.76 1000
ambulance 0.70 0.65 0.67 1000
angel 0.87 0.78 0.82 1000
animal migration 0.47 0.66 0.55 1000
ant 0.77 0.74 0.75 1000
anvil 0.80 0.66 0.72 1000
apple 0.82 0.85 0.83 1000
arm 0.74 0.69 0.71 1000
asparagus 0.54 0.44 0.48 1000
axe 0.69 0.67 0.68 1000
backpack 0.61 0.75 0.67 1000
banana 0.68 0.72 0.70 1000
bandage 0.83 0.71 0.77 1000
barn 0.66 0.68 0.67 1000
baseball 0.77 0.71 0.74 1000
baseball bat 0.75 0.73 0.74 1000
basket 0.71 0.62 0.66 1000
basketball 0.62 0.72 0.66 1000
bat 0.79 0.62 0.69 1000
bathtub 0.60 0.64 0.62 1000
beach 0.58 0.65 0.61 1000
bear 0.46 0.31 0.37 1000
beard 0.56 0.73 0.63 1000
bed 0.80 0.67 0.73 1000
bee 0.82 0.74 0.78 1000
belt 0.78 0.55 0.64 1000
bench 0.59 0.53 0.56 1000
bicycle 0.73 0.72 0.72 1000
binoculars 0.74 0.77 0.76 1000
bird 0.47 0.43 0.45 1000
birthday cake 0.52 0.64 0.57 1000
blackberry 0.46 0.42 0.44 1000
blueberry 0.58 0.47 0.52 1000
book 0.72 0.78 0.75 1000
boomerang 0.73 0.70 0.71 1000
bottlecap 0.58 0.54 0.56 1000
bowtie 0.87 0.86 0.86 1000
bracelet 0.68 0.60 0.64 1000
brain 0.59 0.60 0.59 1000
bread 0.54 0.63 0.58 1000
bridge 0.61 0.64 0.63 1000
broccoli 0.58 0.70 0.64 1000
broom 0.56 0.68 0.61 1000
bucket 0.62 0.66 0.64 1000
bulldozer 0.69 0.70 0.70 1000
bus 0.56 0.42 0.48 1000
bush 0.47 0.65 0.55 1000
butterfly 0.86 0.88 0.87 1000
cactus 0.69 0.87 0.77 1000
cake 0.53 0.42 0.47 1000
calculator 0.76 0.82 0.79 1000
calendar 0.54 0.50 0.52 1000
camel 0.82 0.84 0.83 1000
camera 0.87 0.74 0.80 1000
camouflage 0.23 0.43 0.30 1000
campfire 0.72 0.77 0.75 1000
candle 0.75 0.73 0.74 1000
cannon 0.77 0.69 0.72 1000
canoe 0.67 0.63 0.65 1000
car 0.65 0.63 0.64 1000
carrot 0.75 0.82 0.78 1000
castle 0.79 0.72 0.75 1000
cat 0.69 0.66 0.68 1000
ceiling fan 0.83 0.64 0.72 1000
cell phone 0.62 0.60 0.61 1000
cello 0.51 0.67 0.58 1000
chair 0.83 0.80 0.81 1000
chandelier 0.74 0.71 0.73 1000
church 0.72 0.67 0.69 1000
circle 0.53 0.86 0.66 1000
clarinet 0.53 0.63 0.58 1000
clock 0.86 0.77 0.82 1000
cloud 0.73 0.69 0.71 1000
coffee cup 0.67 0.43 0.52 1000
compass 0.69 0.78 0.73 1000
computer 0.79 0.62 0.69 1000
cookie 0.68 0.80 0.74 1000
cooler 0.47 0.33 0.38 1000
couch 0.76 0.82 0.79 1000
cow 0.70 0.57 0.63 1000
crab 0.70 0.72 0.71 1000
crayon 0.44 0.52 0.47 1000
crocodile 0.65 0.57 0.60 1000
crown 0.87 0.87 0.87 1000
cruise ship 0.76 0.69 0.73 1000
cup 0.43 0.50 0.47 1000
diamond 0.73 0.88 0.80 1000
dishwasher 0.56 0.47 0.51 1000
diving board 0.53 0.54 0.53 1000
dog 0.50 0.41 0.45 1000
dolphin 0.79 0.59 0.68 1000
donut 0.75 0.88 0.81 1000
door 0.69 0.72 0.70 1000
dragon 0.52 0.42 0.47 1000
dresser 0.75 0.65 0.70 1000
drill 0.78 0.71 0.75 1000
drums 0.71 0.68 0.70 1000
duck 0.68 0.49 0.57 1000
dumbbell 0.78 0.80 0.79 1000
ear 0.81 0.75 0.78 1000
elbow 0.74 0.62 0.68 1000
elephant 0.66 0.66 0.66 1000
envelope 0.87 0.94 0.90 1000
eraser 0.50 0.61 0.55 1000
eye 0.83 0.85 0.84 1000
eyeglasses 0.84 0.80 0.82 1000
face 0.62 0.64 0.63 1000
fan 0.76 0.60 0.67 1000
feather 0.58 0.60 0.59 1000
fence 0.67 0.71 0.69 1000
finger 0.70 0.63 0.67 1000
fire hydrant 0.56 0.64 0.60 1000
fireplace 0.74 0.67 0.71 1000
firetruck 0.71 0.50 0.59 1000
fish 0.89 0.85 0.87 1000
flamingo 0.69 0.75 0.72 1000
flashlight 0.80 0.82 0.81 1000
flip flops 0.64 0.75 0.69 1000
floor lamp 0.77 0.70 0.74 1000
flower 0.79 0.83 0.81 1000
flying saucer 0.65 0.64 0.64 1000
foot 0.68 0.66 0.67 1000
fork 0.81 0.79 0.80 1000
frog 0.46 0.47 0.47 1000
frying pan 0.78 0.76 0.77 1000
garden 0.59 0.63 0.61 1000
garden hose 0.42 0.28 0.33 1000
giraffe 0.87 0.80 0.84 1000
goatee 0.72 0.73 0.72 1000
golf club 0.60 0.62 0.61 1000
grapes 0.68 0.65 0.66 1000
grass 0.59 0.83 0.69 1000
guitar 0.68 0.50 0.58 1000
hamburger 0.66 0.83 0.73 1000
hammer 0.71 0.75 0.73 1000
hand 0.83 0.83 0.83 1000
harp 0.83 0.78 0.80 1000
hat 0.72 0.71 0.72 1000
headphones 0.92 0.91 0.92 1000
hedgehog 0.73 0.74 0.73 1000
helicopter 0.81 0.83 0.82 1000
helmet 0.63 0.66 0.64 1000
hexagon 0.70 0.73 0.72 1000
hockey puck 0.59 0.61 0.60 1000
hockey stick 0.59 0.54 0.56 1000
horse 0.53 0.85 0.65 1000
hospital 0.80 0.68 0.74 1000
hot air balloon 0.79 0.72 0.75 1000
hot dog 0.60 0.63 0.62 1000
hot tub 0.58 0.51 0.54 1000
hourglass 0.86 0.87 0.87 1000
house 0.77 0.77 0.77 1000
house plant 0.85 0.82 0.83 1000
hurricane 0.39 0.45 0.42 1000
ice cream 0.82 0.85 0.84 1000
jacket 0.75 0.72 0.74 1000
jail 0.71 0.72 0.71 1000
kangaroo 0.73 0.71 0.72 1000
key 0.71 0.76 0.74 1000
keyboard 0.50 0.48 0.49 1000
knee 0.63 0.68 0.65 1000
ladder 0.88 0.91 0.89 1000
lantern 0.70 0.53 0.60 1000
laptop 0.63 0.80 0.71 1000
leaf 0.73 0.71 0.72 1000
leg 0.58 0.50 0.54 1000
light bulb 0.69 0.79 0.73 1000
lighthouse 0.71 0.74 0.72 1000
lightning 0.76 0.69 0.72 1000
line 0.55 0.82 0.66 1000
lion 0.70 0.76 0.73 1000
lipstick 0.59 0.69 0.63 1000
lobster 0.61 0.47 0.53 1000
lollipop 0.76 0.85 0.80 1000
mailbox 0.75 0.66 0.70 1000
map 0.65 0.73 0.68 1000
marker 0.39 0.16 0.23 1000
matches 0.52 0.47 0.49 1000
megaphone 0.80 0.70 0.75 1000
mermaid 0.76 0.84 0.80 1000
microphone 0.64 0.73 0.68 1000
microwave 0.79 0.75 0.77 1000
monkey 0.59 0.56 0.57 1000
moon 0.69 0.60 0.64 1000
mosquito 0.48 0.55 0.51 1000
motorbike 0.64 0.62 0.63 1000
mountain 0.74 0.80 0.77 1000
mouse 0.53 0.46 0.49 1000
moustache 0.75 0.72 0.73 1000
mouth 0.72 0.76 0.74 1000
mug 0.54 0.65 0.59 1000
mushroom 0.66 0.76 0.70 1000
nail 0.58 0.66 0.62 1000
necklace 0.75 0.63 0.68 1000
nose 0.69 0.75 0.72 1000
ocean 0.54 0.54 0.54 1000
octagon 0.71 0.62 0.66 1000
octopus 0.89 0.83 0.86 1000
onion 0.75 0.68 0.71 1000
oven 0.50 0.39 0.44 1000
owl 0.68 0.65 0.67 1000
paint can 0.51 0.49 0.50 1000
paintbrush 0.58 0.63 0.61 1000
palm tree 0.73 0.83 0.78 1000
panda 0.66 0.62 0.64 1000
pants 0.75 0.68 0.71 1000
paper clip 0.75 0.78 0.76 1000
parachute 0.81 0.79 0.80 1000
parrot 0.54 0.59 0.56 1000
passport 0.60 0.55 0.58 1000
peanut 0.70 0.73 0.71 1000
pear 0.72 0.80 0.76 1000
peas 0.70 0.56 0.62 1000
pencil 0.58 0.60 0.59 1000
penguin 0.69 0.78 0.73 1000
piano 0.65 0.66 0.65 1000
pickup truck 0.60 0.64 0.62 1000
picture frame 0.68 0.89 0.77 1000
pig 0.77 0.56 0.65 1000
pillow 0.60 0.58 0.59 1000
pineapple 0.80 0.85 0.82 1000
pizza 0.65 0.77 0.70 1000
pliers 0.69 0.55 0.61 1000
police car 0.67 0.68 0.67 1000
pond 0.40 0.47 0.43 1000
pool 0.51 0.23 0.32 1000
popsicle 0.70 0.79 0.75 1000
postcard 0.74 0.58 0.65 1000
potato 0.54 0.40 0.46 1000
power outlet 0.61 0.72 0.66 1000
purse 0.64 0.69 0.66 1000
rabbit 0.66 0.80 0.72 1000
raccoon 0.43 0.44 0.44 1000
radio 0.71 0.59 0.64 1000
rain 0.77 0.90 0.83 1000
rainbow 0.79 0.92 0.85 1000
rake 0.69 0.67 0.68 1000
remote control 0.67 0.68 0.67 1000
rhinoceros 0.65 0.75 0.69 1000
river 0.66 0.61 0.64 1000
roller coaster 0.70 0.52 0.60 1000
rollerskates 0.86 0.83 0.84 1000
sailboat 0.84 0.87 0.86 1000
sandwich 0.50 0.68 0.57 1000
saw 0.81 0.83 0.82 1000
saxophone 0.79 0.77 0.78 1000
school bus 0.51 0.44 0.47 1000
scissors 0.80 0.84 0.82 1000
scorpion 0.70 0.76 0.73 1000
screwdriver 0.58 0.62 0.60 1000
sea turtle 0.79 0.73 0.76 1000
see saw 0.85 0.79 0.82 1000
shark 0.72 0.72 0.72 1000
sheep 0.75 0.80 0.77 1000
shoe 0.73 0.75 0.74 1000
shorts 0.67 0.76 0.71 1000
shovel 0.62 0.73 0.67 1000
sink 0.62 0.76 0.68 1000
skateboard 0.83 0.85 0.84 1000
skull 0.86 0.83 0.85 1000
skyscraper 0.65 0.56 0.60 1000
sleeping bag 0.55 0.59 0.57 1000
smiley face 0.74 0.80 0.77 1000
snail 0.79 0.90 0.84 1000
snake 0.65 0.66 0.65 1000
snorkel 0.79 0.73 0.76 1000
snowflake 0.79 0.84 0.81 1000
snowman 0.83 0.90 0.86 1000
soccer ball 0.69 0.70 0.69 1000
sock 0.77 0.75 0.76 1000
speedboat 0.65 0.65 0.65 1000
spider 0.72 0.79 0.76 1000
spoon 0.69 0.57 0.63 1000
spreadsheet 0.67 0.62 0.65 1000
square 0.52 0.84 0.65 1000
squiggle 0.41 0.40 0.40 1000
squirrel 0.71 0.74 0.72 1000
stairs 0.90 0.91 0.90 1000
star 0.93 0.91 0.92 1000
steak 0.53 0.46 0.49 1000
stereo 0.61 0.68 0.64 1000
stethoscope 0.87 0.75 0.81 1000
stitches 0.71 0.79 0.75 1000
stop sign 0.86 0.88 0.87 1000
stove 0.71 0.66 0.69 1000
strawberry 0.80 0.80 0.80 1000
streetlight 0.75 0.71 0.73 1000
string bean 0.51 0.39 0.44 1000
submarine 0.83 0.67 0.74 1000
suitcase 0.75 0.57 0.64 1000
sun 0.87 0.88 0.87 1000
swan 0.69 0.67 0.68 1000
sweater 0.68 0.65 0.67 1000
swing set 0.89 0.90 0.89 1000
sword 0.85 0.81 0.83 1000
t-shirt 0.80 0.78 0.79 1000
table 0.73 0.76 0.74 1000
teapot 0.82 0.77 0.80 1000
teddy-bear 0.66 0.74 0.70 1000
telephone 0.67 0.54 0.60 1000
television 0.88 0.85 0.86 1000
tennis racquet 0.86 0.74 0.80 1000
tent 0.80 0.77 0.78 1000
tiger 0.53 0.47 0.50 1000
toaster 0.59 0.70 0.64 1000
toe 0.67 0.63 0.65 1000
toilet 0.74 0.80 0.77 1000
tooth 0.72 0.74 0.73 1000
toothbrush 0.74 0.76 0.75 1000
toothpaste 0.54 0.56 0.55 1000
tornado 0.63 0.69 0.66 1000
tractor 0.65 0.71 0.68 1000
traffic light 0.84 0.84 0.84 1000
train 0.61 0.74 0.67 1000
tree 0.72 0.75 0.73 1000
triangle 0.87 0.93 0.90 1000
trombone 0.58 0.48 0.53 1000
truck 0.50 0.41 0.45 1000
trumpet 0.65 0.49 0.56 1000
umbrella 0.91 0.86 0.88 1000
underwear 0.83 0.64 0.72 1000
van 0.46 0.58 0.51 1000
vase 0.82 0.67 0.74 1000
violin 0.52 0.52 0.52 1000
washing machine 0.74 0.78 0.76 1000
watermelon 0.56 0.66 0.61 1000
waterslide 0.57 0.70 0.63 1000
whale 0.71 0.74 0.72 1000
wheel 0.82 0.50 0.62 1000
windmill 0.82 0.77 0.79 1000
wine bottle 0.77 0.81 0.79 1000
wine glass 0.86 0.85 0.86 1000
wristwatch 0.72 0.74 0.73 1000
yoga 0.60 0.57 0.58 1000
zebra 0.73 0.66 0.69 1000
zigzag 0.73 0.75 0.74 1000
accuracy 0.68 340000
macro avg 0.69 0.68 0.68 340000
weighted avg 0.69 0.68 0.68 340000
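In the report, the macro avg row is the unweighted mean of the per-class scores. The weighted avg row weights each class by its support. Every class here has exactly 1000 test images, so the two rows coincide. Macro recall also reduces to overall accuracy (both 0.6833 above): the mean of correct_i / 1000 over 340 classes equals total correct / 340000. Macro precision (0.6875) differs because the number of times each class is predicted varies. The NumPy sketch below derives these quantities from a confusion matrix laid out as sklearn's confusion_matrix returns it (rows = true labels, columns = predictions). The function name per_class_scores and the toy matrix are illustrative.

```python
import numpy as np


def per_class_scores(cm):
    """Per-class precision, recall and F1 from a confusion matrix.

    cm[i, j] counts samples whose true class is i and predicted class is j.
    """
    cm = np.asarray(cm, dtype=np.float64)
    tp = np.diag(cm)
    predicted = cm.sum(axis=0)  # column sums: how often each class was predicted
    support = cm.sum(axis=1)    # row sums: how many samples truly belong to each class
    # Return 0 instead of dividing by zero for classes that are never predicted or absent
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, support, out=np.zeros_like(tp), where=support > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return precision, recall, f1, support


if __name__ == '__main__':
    # Toy balanced problem: 3 classes with 10 samples each
    cm = np.array([[8, 1, 1],
                   [2, 6, 2],
                   [1, 2, 7]])
    precision, recall, f1, support = per_class_scores(cm)
    print('macro precision:', precision.mean())
    print('macro recall   :', recall.mean())
    print('weighted recall:', np.average(recall, weights=support))
    print('accuracy       :', np.trace(cm) / cm.sum())
```

In the balanced toy example, macro recall, weighted recall and accuracy all equal 0.7, while macro precision does not. This mirrors the pattern in the report above.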