【Deep Learning in Practice - 11】: Downloading, Parsing, Format Conversion, DDP Distributed Training, and Testing of the Google QuickDraw Dataset with PyTorch


✨ Blog homepage: 王乐予 🎈
✨ Young people should be: Living for the moment! 💪
🏆 Recommended columns: 【图像处理】【千锤百炼Python】【深度学习】【排序算法】

Table of Contents

  • 😺 0. Repository source
  • 😺 1. Dataset introduction
    • 🐶 1.1 GitHub raw dataset
    • 🐶 1.2 GitHub preprocessed dataset
      • 🦄 1.2.1 Simplified drawing files (.ndjson)
      • 🦄 1.2.2 Binary files (.bin)
      • 🦄 1.2.3 Numpy bitmaps (.npy)
    • 🐶 1.3 Kaggle dataset
  • 😺 2. Dataset preparation
  • 😺 3. Generating png images
  • 😺 4. Training process
    • 🐶 4.1 split_datasets.py
    • 🐶 4.2 option.py
    • 🐶 4.3 getdata.py
    • 🐶 4.4 model.py
    • 🐶 4.5 train-DDP.py
    • 🐶 4.6 model_transfer.py
    • 🐶 4.7 evaluate.py

😺 0. Repository source

All the code in this post lives in the GitHub repository QuickDraw-DDP; feel free to fork & star it.

😺 1. Dataset introduction

The Quick Draw dataset is a collection of 50 million drawings across 345 categories, contributed by players of the game Quick, Draw!. The drawings were captured as timestamped vectors and tagged with metadata, including what the player was asked to draw and the player's country/region.

GitHub dataset: 📎 The Quick, Draw! Dataset

Kaggle dataset: 📎 Quick, Draw! Doodle Recognition Challenge

GitHub provides two types of datasets: the raw dataset and the preprocessed dataset.
Google Cloud hosts the download link: quickdraw_dataset

🐶 1.1 GitHub raw dataset

The raw data is provided as ndjson files, one per category, with the following fields:

  • key_id (64-bit unsigned integer): a unique identifier across all drawings
  • word (string): the category the player was asked to draw
  • recognized (boolean): whether the word was recognized by the game
  • timestamp (datetime): when the drawing was created
  • countrycode (string): a two-letter country code (ISO 3166-1 alpha-2) of where the player was located
  • drawing (string): a JSON array representing the vector drawing

Each line contains one drawing; here is an example of a single record:

  { 
    "key_id":"5891796615823360",
    "word":"nose",
    "countrycode":"AE",
    "timestamp":"2017-03-01 20:41:36.70725 UTC",
    "recognized":true,
    "drawing":[[[129,128,129,129,130,130,131,132,132,133,133,133,133,...]]]
  }

The drawing field has the following format:

[ 
  [  // First stroke 
    [x0, x1, x2, x3, ...],
    [y0, y1, y2, y3, ...],
    [t0, t1, t2, t3, ...]
  ],
  [  // Second stroke
    [x0, x1, x2, x3, ...],
    [y0, y1, y2, y3, ...],
    [t0, t1, t2, t3, ...]
  ],
  ... // Additional strokes
]

Here x and y are pixel coordinates and t is the time since the first point, in milliseconds. Because players use different display and input devices, raw drawings can have very different bounding boxes and point counts.
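As a quick illustration of that variability, the following minimal sketch reads one raw record and reports its stroke count, point count, and bounding box (nose.ndjson is just an example file name):

# raw_drawing_stats.py -- a minimal sketch over one raw ndjson record
import json

with open('nose.ndjson', 'r') as f:     # a raw, per-category ndjson file
    record = json.loads(f.readline())

xs = [x for stroke in record['drawing'] for x in stroke[0]]
ys = [y for stroke in record['drawing'] for y in stroke[1]]

print('strokes:', len(record['drawing']))
print('points :', len(xs))
print('bbox   :', (min(xs), min(ys), max(xs), max(ys)))     # raw bboxes differ wildly across devices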

🐶 1.2 GitHub preprocessed dataset

🦄 1.2.1 Simplified drawing files (.ndjson)

The vectors have been simplified, timing information removed, and the data positioned and scaled into a 256x256 region. The data is exported in ndjson format with the same metadata as the raw format. The simplification process is:

  1. Align the drawing to the top-left corner, so the minimum value is 0.
  2. Uniformly scale the drawing so the maximum value is 255.
  3. Resample all strokes with 1-pixel spacing.
  4. Simplify all strokes using the Ramer-Douglas-Peucker algorithm with an epsilon value of 2.0 (see the sketch after this list).
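For intuition, the sketch below is a minimal pure-Python version of the Ramer-Douglas-Peucker step applied to one stroke given as (x, y) pairs; it illustrates the epsilon = 2.0 setting above and is not the official pipeline's implementation:

# rdp_sketch.py -- a minimal illustration of Ramer-Douglas-Peucker
import math

def rdp(points, epsilon=2.0):
    """Recursively drop points whose distance to the start-end chord is below epsilon."""
    if len(points) < 3:
        return points
    (x1, y1), (x2, y2) = points[0], points[-1]
    chord = math.hypot(x2 - x1, y2 - y1) or 1e-9
    # Perpendicular distance of each interior point to the start-end chord
    dists = [abs((y2 - y1) * x - (x2 - x1) * y + x2 * y1 - y2 * x1) / chord
             for x, y in points[1:-1]]
    i = max(range(len(dists)), key=dists.__getitem__) + 1   # farthest interior point
    if dists[i - 1] > epsilon:
        return rdp(points[:i + 1], epsilon)[:-1] + rdp(points[i:], epsilon)
    return [points[0], points[-1]]

stroke = [(0, 0), (1, 0.1), (2, -0.1), (10, 0), (11, 5)]
print(rdp(stroke, epsilon=2.0))   # -> [(0, 0), (10, 0), (11, 5)]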

The code to read an ndjson file is as follows:

# read_ndjson.py
import json

with open('aircraft carrier.ndjson', 'r') as file:
    for line in file:
        data = json.loads(line)
        key_id = data['key_id']
        drawing = data['drawing']
        # ……

Stepping through read_ndjson.py on aircraft carrier.ndjson in a debugger shows that the first record contains 8 strokes.

🦄 1.2.2 Binary files (.bin)

The simplified drawings and metadata are also provided in a custom binary format for efficient compression and loading.

The code to read a bin file is as follows:

# read_bin.py
import struct
from struct import unpack

def unpack_drawing(file_handle):
    key_id, = unpack('Q', file_handle.read(8))
    country_code, = unpack('2s', file_handle.read(2))
    recognized, = unpack('b', file_handle.read(1))
    timestamp, = unpack('I', file_handle.read(4))
    n_strokes, = unpack('H', file_handle.read(2))
    image = []
    for i in range(n_strokes):
        n_points, = unpack('H', file_handle.read(2))
        fmt = str(n_points) + 'B'
        x = unpack(fmt, file_handle.read(n_points))
        y = unpack(fmt, file_handle.read(n_points))
        image.append((x, y))

    return {
        'key_id': key_id,
        'country_code': country_code,
        'recognized': recognized,
        'timestamp': timestamp,
        'image': image
    }

def unpack_drawings(filename):
    with open(filename, 'rb') as f:
        while True:
            try:
                yield unpack_drawing(f)
            except struct.error:
                break

for drawing in unpack_drawings('nose.bin'):
    # do something with the drawing
    print(drawing['country_code'])

🦄 1.2.3 Numpy bitmaps (.npy)

All the simplified drawings have been rendered into 28x28 grayscale bitmaps in numpy format. The images were generated from the simplified data, but aligned to the center of the drawing's bounding box rather than the top-left corner.

The code to read an npy file is as follows:

# read_npy.py
import numpy as np

data_path = 'aircraft_carrier.npy'
 
data = np.load(data_path)
print(data)
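Each row of the loaded array is one flattened 28x28 bitmap, so a single drawing can be previewed like this (a minimal sketch):

# show_npy.py -- a minimal sketch for previewing one bitmap
import numpy as np
import matplotlib.pyplot as plt

data = np.load('aircraft_carrier.npy')      # shape: (num_drawings, 784)
plt.imshow(data[0].reshape(28, 28), cmap='gray')
plt.axis('off')
plt.savefig('npy_preview.png')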

🐶 1.3 Kaggle dataset

The Kaggle competition uses 340 of the categories, with all data provided as CSV tables. The dataset contains 5 files:

  • sample_submission.csv - a sample submission file in the correct format
  • test_raw.csv - the test data in raw vector format
  • test_simplified.csv - the test data in simplified vector format
  • train_raw.zip - training data in raw vector format; one CSV file per word
  • train_simplified.zip - training data in simplified vector format; one CSV file per word

Note: the CSV column titles match the key names of the ndjson files.
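A quick way to confirm this is to peek at one training file (a minimal sketch, assuming the Kaggle archives are unzipped under kaggle/):

# peek_csv.py -- a minimal sketch for inspecting the Kaggle CSV schema
import pandas as pd

df = pd.read_csv('kaggle/train_simplified/nose.csv', nrows=3)
print(df.columns.tolist())      # the same keys as the ndjson records
print(df.head())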

😺 2. Dataset preparation

This post uses the train_simplified dataset provided by Kaggle. The workflow of the case study is:

  1. Convert the CSV files of all classes into png images;
  2. Sample 10,000 png images from each of the 340 classes for the subsequent experiments;
  3. Split each class's 10,000 images into training, validation, and test sets at an 8:1:1 ratio;
  4. Train the model;
  5. Evaluate the model.

😺 3. Generating png images

The script below converts the CSV data into png images and saves them.

# csv2png.py
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import json     # for parsing the 'drawing' stroke lists (the original scipy import here was unused)
import matplotlib
matplotlib.use('Agg')

input_dir = 'kaggle/train_simplified'
output_base_dir = 'datasets256'

os.makedirs(output_base_dir, exist_ok=True)

csv_files = [f for f in os.listdir(input_dir) if f.endswith('.csv')]    # Retrieve all CSV files from the folder

skipped_files = []  # Record skipped files

for csv_file in csv_files:
    
    csv_file_path = os.path.join(input_dir, csv_file)   # Build a complete file path
    output_dir = os.path.join(output_base_dir, os.path.splitext(csv_file)[0])   # Build output directory
   
    if os.path.exists(output_dir):      # Check if the output directory exists
        skipped_files.append(csv_file)
        print(f'The directory already exists, skip file: {csv_file}')
        continue
   
    os.makedirs(output_dir, exist_ok=True)
    
    data = pd.read_csv(csv_file_path)       # Read CSV file
    
    for index, row in data.iterrows():  # Traverse each row of data
        drawing = json.loads(row['drawing'])    # The 'drawing' column holds a JSON list of strokes; safer than eval
        key_id = row['key_id']
        word = row['word']
        
        fig = plt.figure(figsize=(256/96, 256/96), dpi=96)      # 256x256-pixel canvas
        
        for stroke in drawing:      # Draw each stroke as a black polyline
            x = np.array(stroke[0])
            y = np.array(stroke[1])
            plt.plot(x, y, 'k')
        
        ax = plt.gca()
        ax.xaxis.set_ticks_position('top')
        ax.invert_yaxis()
        plt.axis('off')
        plt.savefig(os.path.join(output_dir, f'{word}-{key_id}.png'))
        plt.close(fig)
        print(f'Conversion completed: {csv_file} image {index:06d}')
        
print("The skipped files are:")
for file in skipped_files:
    print(file)

Note that there are roughly 50 million drawings, so processing takes a very long time; it is advisable to run several copies of the script at once (the directory-existence check in the code prevents duplicate writes). You could also use the joblib library for multi-process acceleration, though it can easily hang the machine if mishandled, so it is not recommended; a sketch follows below.
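If you do want to try joblib despite that caveat, a minimal sketch looks like this, assuming the body of the per-file loop in csv2png.py has been refactored into a convert_one_csv(csv_file) function (a hypothetical helper, not part of the repo):

# parallel_convert.py -- a minimal sketch; convert_one_csv is a hypothetical wrapper
# around the per-file loop of csv2png.py
import os
from joblib import Parallel, delayed
from csv2png import convert_one_csv     # hypothetical: assumes csv2png.py exposes this function

input_dir = 'kaggle/train_simplified'
csv_files = [f for f in os.listdir(input_dir) if f.endswith('.csv')]

# Keep n_jobs modest: too many matplotlib processes can exhaust RAM and hang the machine
Parallel(n_jobs=4)(delayed(convert_one_csv)(f) for f in csv_files)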

The storage footprint of the related files is roughly:

  • GitHub preprocessed ndjson files: about 23 GB
  • Kaggle train_raw.zip: about 206 GB
  • Kaggle train_simplified.zip: about 23 GB
  • Kaggle train_simplified converted to 256x256 images: about 470 GB

If disk space is tight, you can render the png images at 128x128 or 64x64 instead, or save single-channel images; a sketch follows below.
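For example, rendering at 64x64 only requires changing the figure size to figsize=(64/96, 64/96); a single-channel copy can also be produced after the fact with PIL (a minimal sketch; the file name is just an example):

# shrink_to_gray.py -- a minimal post-processing sketch using PIL
from PIL import Image

img = Image.open('datasets256/nose/nose-5891796615823360.png')  # hypothetical example file
img.convert('L').resize((64, 64)).save('nose-64-gray.png')      # 'L' = single 8-bit channel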

After processing finishes, it is advisable to check for unprocessed classes with the following script:

# check_class_num.py
import os

folder = 'datasets256'

subfolders = [f.path for f in os.scandir(folder) if f.is_dir()]

for subfolder in subfolders:    # Traverse each subfolder
    folder_name = os.path.basename(subfolder)   # Get the name of the subfolder
    
    files = [f for f in os.scandir(subfolder) if f.is_file()]   # Retrieve all files in the subfolder
    
    image_count = sum(1 for f in files if f.name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')))   # Count the images

    if image_count == 0:        # If there are no images, print the subfolder's name and delete it
        
        print(f"There are no images in subfolder '{folder_name}', deleting it...")
        os.rmdir(subfolder)     # os.rmdir only removes empty directories; use shutil.rmtree if stray files remain
        
        print(f"Subfolder '{folder_name}' deleted")
    else:
        print(f"Number of images in subfolder '{folder_name}': {image_count}")

If any empty folders are found, run the csv2png.py code again.

😺 4. Training process

🐶 4.1 split_datasets.py

First, split the dataset. The raw input is the png image dataset generated above.

import os
import shutil
import random


original_dataset_path = 'datasets256'     # Original dataset path
new_dataset_path = 'datasets'                       # Path of the split dataset

train_path = os.path.join(new_dataset_path, 'train')
val_path = os.path.join(new_dataset_path, 'val')
test_path = os.path.join(new_dataset_path, 'test')

if not os.path.exists(train_path):
    os.makedirs(train_path)

if not os.path.exists(val_path):
    os.makedirs(val_path)

if not os.path.exists(test_path):
    os.makedirs(test_path)

classes = os.listdir(original_dataset_path)     # Get all categories

random.seed(42)

for class_name in classes:      # Traverse each category
    
    src_folder = os.path.join(original_dataset_path, class_name)    # Source folder path
    
    # Check if the folder for this category already exists under train, val, and test
    train_folder = os.path.join(train_path, class_name)
    val_folder = os.path.join(val_path, class_name)
    test_folder = os.path.join(test_path, class_name)

    # If the train, val, and test folders already exist, skip the folder creation section
    if os.path.exists(train_folder) and os.path.exists(val_folder) and os.path.exists(test_folder):
        # Check if the folder is empty
        if os.listdir(train_folder) and os.listdir(val_folder) and os.listdir(test_folder):
            print(f"Category {class_name} already exists and is not empty, skip processing.")
            continue

    # create folder
    if not os.path.exists(train_folder):
        os.makedirs(train_folder)

    if not os.path.exists(val_folder):
        os.makedirs(val_folder)

    if not os.path.exists(test_folder):
        os.makedirs(test_folder)

    
    files = os.listdir(src_folder)      # Retrieve all file names under this category
    files = files[:10000]       # Only retrieve the first 10000 files
    random.shuffle(files)       # Shuffle file list

    total_files = len(files)
    train_split_index = int(total_files * 0.8)
    val_split_index = int(total_files * 0.9)

    train_files = files[:train_split_index]
    val_files = files[train_split_index:val_split_index]
    test_files = files[val_split_index:]

    for file in train_files:
        src_file = os.path.join(src_folder, file)
        dst_file = os.path.join(train_folder, file)
        shutil.copy(src_file, dst_file)

    for file in val_files:
        src_file = os.path.join(src_folder, file)
        dst_file = os.path.join(val_folder, file)
        shutil.copy(src_file, dst_file)

    for file in test_files:
        src_file = os.path.join(src_folder, file)
        dst_file = os.path.join(test_folder, file)
        shutil.copy(src_file, dst_file)

print("Dataset partitioning completed!")

After the script finishes, three folders appear under the datasets directory: train, val, and test.

🐶 4.2 option.py

Defines the arguments we will need later.

import argparse


def get_args():
    parser = argparse.ArgumentParser(description='all argument')
    parser.add_argument('--num_classes', type=int, default=340, help='image num classes')
    parser.add_argument('--loadsize', type=int, default=64, help='image size')
    parser.add_argument('--epochs', type=int, default=100, help='all epochs')
    parser.add_argument('--batch_size', type=int, default=1024, help='batch size')
    parser.add_argument('--lr', type=float, default=0.001, help='init lr')
    parser.add_argument('--use_lr_scheduler', type=bool, default=True, help='use lr scheduler')     # NOTE: argparse's type=bool turns any non-empty string (even "False") into True; this applies to all bool arguments here
    parser.add_argument('--dataset_train', type=str, default='./datasets/train', help='train path')
    parser.add_argument('--dataset_val', type=str, default="./datasets/val", help='val path')
    parser.add_argument('--dataset_test', type=str, default="./datasets/test", help='test path')
    parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='ckpt path')
    parser.add_argument('--tensorboard_dir', type=str, default='./tensorboard_dir', help='log path')
    parser.add_argument('--resume', type=bool, default=False, help='continue training')
    parser.add_argument('--resume_ckpt', type=str, default='./checkpoints/model_best.pth', help='choose breakpoint ckpt')
    parser.add_argument('--local-rank', type=int, default=-1, help='local rank')
    parser.add_argument('--use_mix_precision', type=bool, default=False, help='use mix pretrain')
    parser.add_argument('--test_img_path', type=str, default='datasets/test/zigzag/zigzag-4508464694951936.png', help='choose test image')
    parser.add_argument('--test_dir_path', type=str, default='./datasets/test', help='choose test path')
    
    return parser.parse_args()

Since training will later use DDP (single machine, multiple GPUs) together with AMP, the extra local-rank and use_mix_precision arguments are included.
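One caveat: torch.distributed.launch is deprecated, and its replacement torchrun passes the local rank through the LOCAL_RANK environment variable rather than a --local-rank argument. A minimal compatibility sketch (an assumption about your launcher, not part of the original code):

# local_rank_compat.py -- a minimal sketch for supporting both launchers
import os
import torch
from option import get_args

opt = get_args()
# torchrun sets LOCAL_RANK in the environment; torch.distributed.launch passes --local-rank
local_rank = int(os.environ.get('LOCAL_RANK', opt.local_rank))
torch.cuda.set_device(local_rank)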

🐶 4.3 getdata.py

Next, define the data pipeline.

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from option import get_args
opt = get_args()

mean = [0.9367, 0.9404, 0.9405]
std = [0.1971, 0.1970, 0.1972]
def data_augmentation():
    data_transform = {
        'train': transforms.Compose([
            transforms.Resize((opt.loadsize, opt.loadsize)),
            transforms.ToTensor(),  # HWC -> CHW
            transforms.Normalize(mean, std)
        ]),
        'val': transforms.Compose([
            transforms.Resize((opt.loadsize, opt.loadsize)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]),
    }
    return data_transform


def MyData():
    data_transform = data_augmentation()
    image_datasets = {
        'train': ImageFolder(opt.dataset_train, data_transform['train']),
        'val': ImageFolder(opt.dataset_val, data_transform['val']),
    }
    data_sampler = {
        'train': torch.utils.data.distributed.DistributedSampler(image_datasets['train']),
        'val': torch.utils.data.distributed.DistributedSampler(image_datasets['val']),
        }
    dataloaders = {
        # shuffle must stay False here: shuffling is delegated to the DistributedSampler
        'train': DataLoader(image_datasets['train'], batch_size=opt.batch_size, shuffle=False, num_workers=0, pin_memory=True, sampler=data_sampler['train']),
        'val': DataLoader(image_datasets['val'], batch_size=opt.batch_size, shuffle=False, num_workers=0, pin_memory=True, sampler=data_sampler['val'])
    }
    return dataloaders

class_names =[
    'The Eiffel Tower', 'The Great Wall of China', 'The Mona Lisa', 'airplane', 'alarm clock', 'ambulance', 'angel', 
    'animal migration', 'ant', 'anvil', 'apple', 'arm', 'asparagus', 'axe', 'backpack', 'banana', 'bandage', 'barn', 
    'baseball', 'baseball bat', 'basket', 'basketball', 'bat', 'bathtub', 'beach', 'bear', 'beard', 'bed', 'bee', 
    'belt', 'bench', 'bicycle', 'binoculars', 'bird', 'birthday cake', 'blackberry', 'blueberry', 'book', 'boomerang', 
    'bottlecap', 'bowtie', 'bracelet', 'brain', 'bread', 'bridge', 'broccoli', 'broom', 'bucket', 'bulldozer', 'bus', 
    'bush', 'butterfly', 'cactus', 'cake', 'calculator', 'calendar', 'camel', 'camera', 'camouflage', 'campfire', 
    'candle', 'cannon', 'canoe', 'car', 'carrot', 'castle', 'cat', 'ceiling fan', 'cell phone', 'cello', 'chair', 
    'chandelier', 'church', 'circle', 'clarinet', 'clock', 'cloud', 'coffee cup', 'compass', 'computer', 'cookie', 
    'cooler', 'couch', 'cow', 'crab', 'crayon', 'crocodile', 'crown', 'cruise ship', 'cup', 'diamond', 'dishwasher', 
    'diving board', 'dog', 'dolphin', 'donut', 'door', 'dragon', 'dresser', 'drill', 'drums', 'duck', 'dumbbell', 
    'ear', 'elbow', 'elephant', 'envelope', 'eraser', 'eye', 'eyeglasses', 'face', 'fan', 'feather', 'fence', 'finger', 
    'fire hydrant', 'fireplace', 'firetruck', 'fish', 'flamingo', 'flashlight', 'flip flops', 'floor lamp', 'flower', 
    'flying saucer', 'foot', 'fork', 'frog', 'frying pan', 'garden', 'garden hose', 'giraffe', 'goatee', 'golf club', 
    'grapes', 'grass', 'guitar', 'hamburger', 'hammer', 'hand', 'harp', 'hat', 'headphones', 'hedgehog', 'helicopter', 
    'helmet', 'hexagon', 'hockey puck', 'hockey stick', 'horse', 'hospital', 'hot air balloon', 'hot dog', 'hot tub', 
    'hourglass', 'house', 'house plant', 'hurricane', 'ice cream', 'jacket', 'jail', 'kangaroo', 'key', 'keyboard', 
    'knee', 'ladder', 'lantern', 'laptop', 'leaf', 'leg', 'light bulb', 'lighthouse', 'lightning', 'line', 'lion', 
    'lipstick', 'lobster', 'lollipop', 'mailbox', 'map', 'marker', 'matches', 'megaphone', 'mermaid', 'microphone', 
    'microwave', 'monkey', 'moon', 'mosquito', 'motorbike', 'mountain', 'mouse', 'moustache', 'mouth', 'mug', 'mushroom', 
    'nail', 'necklace', 'nose', 'ocean', 'octagon', 'octopus', 'onion', 'oven', 'owl', 'paint can', 'paintbrush', 
    'palm tree', 'panda', 'pants', 'paper clip', 'parachute', 'parrot', 'passport', 'peanut', 'pear', 'peas', 'pencil', 
    'penguin', 'piano', 'pickup truck', 'picture frame', 'pig', 'pillow', 'pineapple', 'pizza', 'pliers', 'police car', 
    'pond', 'pool', 'popsicle', 'postcard', 'potato', 'power outlet', 'purse', 'rabbit', 'raccoon', 'radio', 'rain', 
    'rainbow', 'rake', 'remote control', 'rhinoceros', 'river', 'roller coaster', 'rollerskates', 'sailboat', 'sandwich', 
    'saw', 'saxophone', 'school bus', 'scissors', 'scorpion', 'screwdriver', 'sea turtle', 'see saw', 'shark', 'sheep', 
    'shoe', 'shorts', 'shovel', 'sink', 'skateboard', 'skull', 'skyscraper', 'sleeping bag', 'smiley face', 'snail', 
    'snake', 'snorkel', 'snowflake', 'snowman', 'soccer ball', 'sock', 'speedboat', 'spider', 'spoon', 'spreadsheet', 
    'square', 'squiggle', 'squirrel', 'stairs', 'star', 'steak', 'stereo', 'stethoscope', 'stitches', 'stop sign', 
    'stove', 'strawberry', 'streetlight', 'string bean', 'submarine', 'suitcase', 'sun', 'swan', 'sweater', 'swing set', 
    'sword', 't-shirt', 'table', 'teapot', 'teddy-bear', 'telephone', 'television', 'tennis racquet', 'tent', 'tiger', 
    'toaster', 'toe', 'toilet', 'tooth', 'toothbrush', 'toothpaste', 'tornado', 'tractor', 'traffic light', 'train', 
    'tree', 'triangle', 'trombone', 'truck', 'trumpet', 'umbrella', 'underwear', 'van', 'vase', 'violin', 'washing machine', 
    'watermelon', 'waterslide', 'whale', 'wheel', 'windmill', 'wine bottle', 'wine glass', 'wristwatch', 'yoga', 'zebra', 
    'zigzag'
]

if __name__ == '__main__':
    mean_std_transform = transforms.Compose([transforms.ToTensor()])
    dataset = ImageFolder(opt.dataset_val, transform=mean_std_transform)
    print(dataset.class_to_idx)		# Index for each category
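
The hard-coded mean and std at the top of getdata.py are statistics of this particular image set; if you regenerate the images (e.g., at a different size), they can be re-measured with a one-pass sketch like the following (estimate_mean_std.py is a hypothetical helper, not part of the repo):

# estimate_mean_std.py -- hypothetical helper for re-measuring normalization statistics
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

dataset = ImageFolder('datasets/val', transform=transforms.ToTensor())
loader = DataLoader(dataset, batch_size=256, num_workers=4)

channel_sum = torch.zeros(3)
channel_sq_sum = torch.zeros(3)
n_pixels = 0
for images, _ in loader:                        # images: (B, 3, H, W) in [0, 1]
    channel_sum += images.sum(dim=[0, 2, 3])
    channel_sq_sum += (images ** 2).sum(dim=[0, 2, 3])
    n_pixels += images.numel() // 3             # pixels per channel
mean = channel_sum / n_pixels
std = (channel_sq_sum / n_pixels - mean ** 2).sqrt()
print('mean:', mean.tolist(), 'std:', std.tolist())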

🐶 4.4 model.py

Defines the model; here we use the small variant of MobileNetV3, replacing the output dimension of the model's classifier layer with the number of classes.
You can also train this dataset with other strong lightweight models, such as ShuffleNet or SqueezeNet.

import torch
import torch.nn as nn
from torchvision.models import mobilenet_v3_small
from torchsummary import summary
from option import get_args
opt = get_args()


def CustomMobileNetV3():
    model = mobilenet_v3_small(weights='MobileNet_V3_Small_Weights.IMAGENET1K_V1')
    model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, opt.num_classes)
    return model


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'    # NOTE: option.py defines no --device argument
    model = CustomMobileNetV3()
    print(model)
    print(summary(model.to(device), (3, opt.loadsize, opt.loadsize), opt.batch_size))

The model structure is as follows:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [1024, 16, 32, 32]             432
       BatchNorm2d-2         [1024, 16, 32, 32]              32
         Hardswish-3         [1024, 16, 32, 32]               0
            Conv2d-4         [1024, 16, 16, 16]             144
       BatchNorm2d-5         [1024, 16, 16, 16]              32
              ReLU-6         [1024, 16, 16, 16]               0
 AdaptiveAvgPool2d-7           [1024, 16, 1, 1]               0
            Conv2d-8            [1024, 8, 1, 1]             136
              ReLU-9            [1024, 8, 1, 1]               0
           Conv2d-10           [1024, 16, 1, 1]             144
      Hardsigmoid-11           [1024, 16, 1, 1]               0
SqueezeExcitation-12         [1024, 16, 16, 16]               0
           Conv2d-13         [1024, 16, 16, 16]             256
      BatchNorm2d-14         [1024, 16, 16, 16]              32
 InvertedResidual-15         [1024, 16, 16, 16]               0
           Conv2d-16         [1024, 72, 16, 16]           1,152
      BatchNorm2d-17         [1024, 72, 16, 16]             144
             ReLU-18         [1024, 72, 16, 16]               0
           Conv2d-19           [1024, 72, 8, 8]             648
      BatchNorm2d-20           [1024, 72, 8, 8]             144
             ReLU-21           [1024, 72, 8, 8]               0
           Conv2d-22           [1024, 24, 8, 8]           1,728
      BatchNorm2d-23           [1024, 24, 8, 8]              48
 InvertedResidual-24           [1024, 24, 8, 8]               0
           Conv2d-25           [1024, 88, 8, 8]           2,112
      BatchNorm2d-26           [1024, 88, 8, 8]             176
             ReLU-27           [1024, 88, 8, 8]               0
           Conv2d-28           [1024, 88, 8, 8]             792
      BatchNorm2d-29           [1024, 88, 8, 8]             176
             ReLU-30           [1024, 88, 8, 8]               0
           Conv2d-31           [1024, 24, 8, 8]           2,112
      BatchNorm2d-32           [1024, 24, 8, 8]              48
 InvertedResidual-33           [1024, 24, 8, 8]               0
           Conv2d-34           [1024, 96, 8, 8]           2,304
      BatchNorm2d-35           [1024, 96, 8, 8]             192
        Hardswish-36           [1024, 96, 8, 8]               0
           Conv2d-37           [1024, 96, 4, 4]           2,400
      BatchNorm2d-38           [1024, 96, 4, 4]             192
        Hardswish-39           [1024, 96, 4, 4]               0
AdaptiveAvgPool2d-40           [1024, 96, 1, 1]               0
           Conv2d-41           [1024, 24, 1, 1]           2,328
             ReLU-42           [1024, 24, 1, 1]               0
           Conv2d-43           [1024, 96, 1, 1]           2,400
      Hardsigmoid-44           [1024, 96, 1, 1]               0
SqueezeExcitation-45           [1024, 96, 4, 4]               0
           Conv2d-46           [1024, 40, 4, 4]           3,840
      BatchNorm2d-47           [1024, 40, 4, 4]              80
 InvertedResidual-48           [1024, 40, 4, 4]               0
           Conv2d-49          [1024, 240, 4, 4]           9,600
      BatchNorm2d-50          [1024, 240, 4, 4]             480
        Hardswish-51          [1024, 240, 4, 4]               0
           Conv2d-52          [1024, 240, 4, 4]           6,000
      BatchNorm2d-53          [1024, 240, 4, 4]             480
        Hardswish-54          [1024, 240, 4, 4]               0
AdaptiveAvgPool2d-55          [1024, 240, 1, 1]               0
           Conv2d-56           [1024, 64, 1, 1]          15,424
             ReLU-57           [1024, 64, 1, 1]               0
           Conv2d-58          [1024, 240, 1, 1]          15,600
      Hardsigmoid-59          [1024, 240, 1, 1]               0
SqueezeExcitation-60          [1024, 240, 4, 4]               0
           Conv2d-61           [1024, 40, 4, 4]           9,600
      BatchNorm2d-62           [1024, 40, 4, 4]              80
 InvertedResidual-63           [1024, 40, 4, 4]               0
           Conv2d-64          [1024, 240, 4, 4]           9,600
      BatchNorm2d-65          [1024, 240, 4, 4]             480
        Hardswish-66          [1024, 240, 4, 4]               0
           Conv2d-67          [1024, 240, 4, 4]           6,000
      BatchNorm2d-68          [1024, 240, 4, 4]             480
        Hardswish-69          [1024, 240, 4, 4]               0
AdaptiveAvgPool2d-70          [1024, 240, 1, 1]               0
           Conv2d-71           [1024, 64, 1, 1]          15,424
             ReLU-72           [1024, 64, 1, 1]               0
           Conv2d-73          [1024, 240, 1, 1]          15,600
      Hardsigmoid-74          [1024, 240, 1, 1]               0
SqueezeExcitation-75          [1024, 240, 4, 4]               0
           Conv2d-76           [1024, 40, 4, 4]           9,600
      BatchNorm2d-77           [1024, 40, 4, 4]              80
 InvertedResidual-78           [1024, 40, 4, 4]               0
           Conv2d-79          [1024, 120, 4, 4]           4,800
      BatchNorm2d-80          [1024, 120, 4, 4]             240
        Hardswish-81          [1024, 120, 4, 4]               0
           Conv2d-82          [1024, 120, 4, 4]           3,000
      BatchNorm2d-83          [1024, 120, 4, 4]             240
        Hardswish-84          [1024, 120, 4, 4]               0
AdaptiveAvgPool2d-85          [1024, 120, 1, 1]               0
           Conv2d-86           [1024, 32, 1, 1]           3,872
             ReLU-87           [1024, 32, 1, 1]               0
           Conv2d-88          [1024, 120, 1, 1]           3,960
      Hardsigmoid-89          [1024, 120, 1, 1]               0
SqueezeExcitation-90          [1024, 120, 4, 4]               0
           Conv2d-91           [1024, 48, 4, 4]           5,760
      BatchNorm2d-92           [1024, 48, 4, 4]              96
 InvertedResidual-93           [1024, 48, 4, 4]               0
           Conv2d-94          [1024, 144, 4, 4]           6,912
      BatchNorm2d-95          [1024, 144, 4, 4]             288
        Hardswish-96          [1024, 144, 4, 4]               0
           Conv2d-97          [1024, 144, 4, 4]           3,600
      BatchNorm2d-98          [1024, 144, 4, 4]             288
        Hardswish-99          [1024, 144, 4, 4]               0
AdaptiveAvgPool2d-100          [1024, 144, 1, 1]               0
          Conv2d-101           [1024, 40, 1, 1]           5,800
            ReLU-102           [1024, 40, 1, 1]               0
          Conv2d-103          [1024, 144, 1, 1]           5,904
     Hardsigmoid-104          [1024, 144, 1, 1]               0
SqueezeExcitation-105          [1024, 144, 4, 4]               0
          Conv2d-106           [1024, 48, 4, 4]           6,912
     BatchNorm2d-107           [1024, 48, 4, 4]              96
InvertedResidual-108           [1024, 48, 4, 4]               0
          Conv2d-109          [1024, 288, 4, 4]          13,824
     BatchNorm2d-110          [1024, 288, 4, 4]             576
       Hardswish-111          [1024, 288, 4, 4]               0
          Conv2d-112          [1024, 288, 2, 2]           7,200
     BatchNorm2d-113          [1024, 288, 2, 2]             576
       Hardswish-114          [1024, 288, 2, 2]               0
AdaptiveAvgPool2d-115          [1024, 288, 1, 1]               0
          Conv2d-116           [1024, 72, 1, 1]          20,808
            ReLU-117           [1024, 72, 1, 1]               0
          Conv2d-118          [1024, 288, 1, 1]          21,024
     Hardsigmoid-119          [1024, 288, 1, 1]               0
SqueezeExcitation-120          [1024, 288, 2, 2]               0
          Conv2d-121           [1024, 96, 2, 2]          27,648
     BatchNorm2d-122           [1024, 96, 2, 2]             192
InvertedResidual-123           [1024, 96, 2, 2]               0
          Conv2d-124          [1024, 576, 2, 2]          55,296
     BatchNorm2d-125          [1024, 576, 2, 2]           1,152
       Hardswish-126          [1024, 576, 2, 2]               0
          Conv2d-127          [1024, 576, 2, 2]          14,400
     BatchNorm2d-128          [1024, 576, 2, 2]           1,152
       Hardswish-129          [1024, 576, 2, 2]               0
AdaptiveAvgPool2d-130          [1024, 576, 1, 1]               0
          Conv2d-131          [1024, 144, 1, 1]          83,088
            ReLU-132          [1024, 144, 1, 1]               0
          Conv2d-133          [1024, 576, 1, 1]          83,520
     Hardsigmoid-134          [1024, 576, 1, 1]               0
SqueezeExcitation-135          [1024, 576, 2, 2]               0
          Conv2d-136           [1024, 96, 2, 2]          55,296
     BatchNorm2d-137           [1024, 96, 2, 2]             192
InvertedResidual-138           [1024, 96, 2, 2]               0
          Conv2d-139          [1024, 576, 2, 2]          55,296
     BatchNorm2d-140          [1024, 576, 2, 2]           1,152
       Hardswish-141          [1024, 576, 2, 2]               0
          Conv2d-142          [1024, 576, 2, 2]          14,400
     BatchNorm2d-143          [1024, 576, 2, 2]           1,152
       Hardswish-144          [1024, 576, 2, 2]               0
AdaptiveAvgPool2d-145          [1024, 576, 1, 1]               0
          Conv2d-146          [1024, 144, 1, 1]          83,088
            ReLU-147          [1024, 144, 1, 1]               0
          Conv2d-148          [1024, 576, 1, 1]          83,520
     Hardsigmoid-149          [1024, 576, 1, 1]               0
SqueezeExcitation-150          [1024, 576, 2, 2]               0
          Conv2d-151           [1024, 96, 2, 2]          55,296
     BatchNorm2d-152           [1024, 96, 2, 2]             192
InvertedResidual-153           [1024, 96, 2, 2]               0
          Conv2d-154          [1024, 576, 2, 2]          55,296
     BatchNorm2d-155          [1024, 576, 2, 2]           1,152
       Hardswish-156          [1024, 576, 2, 2]               0
AdaptiveAvgPool2d-157          [1024, 576, 1, 1]               0
          Linear-158               [1024, 1024]         590,848
       Hardswish-159               [1024, 1024]               0
         Dropout-160               [1024, 1024]               0
          Linear-161                [1024, 340]         348,500
================================================================
Total params: 1,866,356
Trainable params: 1,866,356
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 48.00
Forward/backward pass size (MB): 2979.22
Params size (MB): 7.12
Estimated Total Size (MB): 3034.34
----------------------------------------------------------------

🐶 4.5 train-DDP.py

Note that train-DDP.py bundles several training strategies:

  • DDP distributed training (single machine, two GPUs);
  • AMP mixed-precision training;
  • learning-rate decay;
  • early stopping;
  • resuming from a checkpoint.
# python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="192.168.8.89" --master_port=12345 train-DDP.py --use_mix_precision True
# Watch the training log: tensorboard --logdir=tensorboard_dir
from tqdm import tqdm
import torch
import torch.nn.parallel
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
import time
import os
import torch.optim
import torch.utils.data
import torch.nn as nn
from collections import OrderedDict
from model import CustomMobileNetV3
from getdata import MyData
from torch.cuda.amp import GradScaler
from option import get_args
opt = get_args()
dist.init_process_group(backend='nccl', init_method='env://')

os.makedirs(opt.checkpoints, exist_ok=True)


def train(gpu):
    rank = dist.get_rank()
    model = CustomMobileNetV3()
    model.cuda(gpu)
    criterion = nn.CrossEntropyLoss().to(gpu)
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    scaler = GradScaler(enabled=opt.use_mix_precision)  

    dataloaders = MyData()
    train_loader = dataloaders['train']
    test_loader = dataloaders['val']
    
    if opt.use_lr_scheduler:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
    
    start_time = time.time()
    best_val_acc = 0.0
    no_improve_epochs = 0
    early_stopping_patience = 6  # Early Stopping Patience
    
    """breakckpt resume"""
    if opt.resume:
        checkpoint = torch.load(opt.resume_ckpt)
        print('Loading checkpoint from:', opt.resume_ckpt)
        new_state_dict = OrderedDict()      # Create a new ordered dictionary and remove prefixes
        for k, v in checkpoint['model'].items():
            name = k[7:]                    # Remove 'module.' To match the original model definition
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict, strict=False)     # Load a new state dictionary
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']                       # Set the starting epoch
        if opt.use_lr_scheduler:
            scheduler.load_state_dict(checkpoint['scheduler'])
    else:
        start_epoch = 0
        
    for epoch in range(start_epoch + 1, opt.epochs):
        model.train()       # valid() leaves the model in eval mode; switch back for training
        train_loader.sampler.set_epoch(epoch)   # Reshuffle the DistributedSampler every epoch
        tqdm_trainloader = tqdm(train_loader, desc=f'Epoch {epoch}')
        running_loss, running_correct_top1, running_correct_top3, running_correct_top5 = 0.0, 0.0, 0.0, 0.0
        total_samples = 0
        for i, (images, target) in enumerate(tqdm_trainloader if rank == 0 else train_loader, 0):
            images = images.to(gpu)
            target = target.to(gpu)

            # Only the forward pass runs under autocast; backward and the optimizer step stay outside
            with torch.cuda.amp.autocast(enabled=opt.use_mix_precision):
                output = model(images)
                loss = criterion(output, target)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(output.data, 1)
            running_correct_top1 += (predicted == target).sum().item()
            _, predicted_top3 = torch.topk(output.data, 3, dim=1)
            _, predicted_top5 = torch.topk(output.data, 5, dim=1)
            running_correct_top3 += (predicted_top3 == target.unsqueeze(1).expand_as(predicted_top3)).sum().item()
            running_correct_top5 += (predicted_top5 == target.unsqueeze(1).expand_as(predicted_top5)).sum().item()
            total_samples += target.size(0)
            
        state = {'epoch': epoch,
                 'model': model.module.state_dict(),
                 'optimizer': optimizer.state_dict()}
        if opt.use_lr_scheduler:        # the scheduler only exists when the flag is set
            state['scheduler'] = scheduler.state_dict()
        
        if rank == 0:
            current_lr = scheduler.get_last_lr()[0] if opt.use_lr_scheduler else opt.lr
            # These running statistics cover only rank 0's shard of the training data
            print(f'[Epoch {epoch}]  '
                    f'[Train Loss: {running_loss / total_samples:.6f}]  '
                    f'[Train Top-1 Acc: {running_correct_top1 / total_samples:.6f}]  '
                    f'[Train Top-3 Acc: {running_correct_top3 / total_samples:.6f}]  '
                    f'[Train Top-5 Acc: {running_correct_top5 / total_samples:.6f}]  '
                    f'[Learning Rate: {current_lr:.6f}]  '
                    f'[Time: {time.time() - start_time:.6f} seconds]')
            writer.add_scalar('Train/Loss', running_loss / total_samples, epoch)
            writer.add_scalar('Train/Top-1 Accuracy', running_correct_top1 / total_samples, epoch)
            writer.add_scalar('Train/Top-3 Accuracy', running_correct_top3 / total_samples, epoch)
            writer.add_scalar('Train/Top-5 Accuracy', running_correct_top5 / total_samples, epoch)
            writer.add_scalar('Train/Learning Rate', current_lr, epoch)
            
            torch.save(state, f'{opt.checkpoints}model_epoch_{epoch}.pth')
            
        tqdm_trainloader.close()
        
        if opt.use_lr_scheduler:    # Learning-rate Scheduler
            scheduler.step()
        
        acc_top1 = valid(test_loader, model, epoch, gpu, rank)
        if acc_top1 is not None:
            if acc_top1 > best_val_acc:
                best_val_acc = acc_top1
                no_improve_epochs = 0
                torch.save(state, f'{opt.checkpoints}/model_best.pth')
            else:
                no_improve_epochs += 1
                if no_improve_epochs >= early_stopping_patience:
                    print(f'Early stopping triggered after {early_stopping_patience} epochs without improvement.')
                    break
        else:
            print("Warning: acc_top1 is None, skipping this epoch.")
        
    dist.destroy_process_group()

def valid(val_loader, model, epoch, gpu, rank):
    model.eval()
    correct_top1, correct_top3, correct_top5, total = torch.tensor(0.).to(gpu), torch.tensor(0.).to(gpu), torch.tensor(0.).to(gpu), torch.tensor(0.).to(gpu)
    with torch.no_grad():
        tqdm_valloader = tqdm(val_loader, desc=f'Epoch {epoch}')
        for i, (images, target) in enumerate(tqdm_valloader, 0) :
            images = images.to(gpu)
            target = target.to(gpu)
            output = model(images)
            total += target.size(0)
            correct_top1  += (output.argmax(1) == target).type(torch.float).sum()
            _, predicted_top3 = torch.topk(output, 3, dim=1)
            _, predicted_top5 = torch.topk(output, 5, dim=1)
            correct_top3 += (predicted_top3[:, :3] == target.unsqueeze(1).expand_as(predicted_top3)).sum().item()
            correct_top5 += (predicted_top5[:, :5] == target.unsqueeze(1).expand_as(predicted_top5)).sum().item()
            
    dist.reduce(total, 0, op=dist.ReduceOp.SUM)     # Group communication reduce operation (change to allreduce if Gloo)
    dist.reduce(correct_top1, 0, op=dist.ReduceOp.SUM)
    dist.reduce(correct_top3, 0, op=dist.ReduceOp.SUM)
    dist.reduce(correct_top5, 0, op=dist.ReduceOp.SUM)

    if rank == 0:
        print(f'[Epoch {epoch}]  '
                f'[Val Top-1 Acc: {correct_top1 / total:.6f}]  '
                f'[Val Top-3 Acc: {correct_top3 / total:.6f}]  '
                f'[Val Top-5 Acc: {correct_top5 / total:.6f}]')
        writer.add_scalar('Validation/Top-1 Accuracy', correct_top1 / total, epoch)
        writer.add_scalar('Validation/Top-3 Accuracy', correct_top3 / total, epoch)
        writer.add_scalar('Validation/Top-5 Accuracy', correct_top5 / total, epoch)

    tqdm_valloader.close()
    return float(correct_top1 / total)  # Top-1 accuracy (the reduced value is only complete on rank 0)


def main():
    train(opt.local_rank)


if __name__ == '__main__':
    writer = SummaryWriter(log_dir=opt.tensorboard_dir)
    main()
    writer.close()

Use the following command in a terminal to launch multi-GPU distributed training:

python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="192.168.8.89" --master_port=12345 train-DDP.py --use_mix_precision True

The arguments mean:

  • nproc_per_node: number of GPUs per machine
  • nnodes: number of machines
  • node_rank: index of this machine
  • master_addr: IP address of the master machine
  • master_port: port on the master machine
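
Note that python -m torch.distributed.launch is deprecated in recent PyTorch releases. The roughly equivalent torchrun command is sketched below; because torchrun passes the rank via the LOCAL_RANK environment variable instead of --local-rank, it needs the compatibility handling sketched in section 4.2:

torchrun --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="192.168.8.89" --master_port=12345 train-DDP.py --use_mix_precision True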

If you launch the training with nohup, you will run into a bug:

W0914 18:33:15.081479 140031432897728 torch/distributed/elastic/agent/server/api.py:741] Received Signals.SIGHUP death signal, shutting down workers
W0914 18:33:15.085310 140031432897728 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 1685186 closing signal SIGHUP
W0914 18:33:15.085644 140031432897728 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 1685192 closing signal SIGHUP

For the underlying cause, see the official PyTorch discussion: DDP Error: torch.distributed.elastic.agent.server.api:Received 1 death signal, shutting down workers

We can work around this with tmux.

  1. Install tmux: sudo apt-get install tmux
  2. Create a session: tmux new -s train-DDP (pick any session name)
  3. Activate the virtual environment: conda activate pytorch (use whichever environment you need)
  4. Launch the training task: python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="192.168.8.89" --master_port=12345 train-DDP.py --use_mix_precision True

Common tmux commands:

  • List all current tmux sessions: tmux ls
  • Create a session: tmux new -s <session-name>
  • Re-attach to a session: tmux attach -t <session-name>
  • Kill a session: tmux kill-session -t <session-name>

The training logs for this run (screenshots omitted here) show that early stopping was triggered at epoch 11.

🐶 4.6 model_transfer.py

This script converts the .pth model into the mobile .ptl format and the ONNX format, which makes on-device deployment easier.

from torch.utils.mobile_optimizer import optimize_for_mobile
import torch
from model import CustomMobileNetV3
import onnx
from onnxsim import simplify
from option import get_args
opt = get_args()


model = CustomMobileNetV3()
model.load_state_dict(torch.load(f'{opt.checkpoints}model_best.pth', map_location='cpu')['model'])
model.eval()
print("Model loaded successfully.")


"""Save .pth format model"""
torch.save(model, f'{opt.checkpoints}model.pth')    # opt.checkpoints already ends with '/'


"""Save .ptl format model"""
example = torch.rand(1, 3, 64, 64)
traced_script_module = torch.jit.trace(model, example)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized._save_for_lite_interpreter(f'{opt.checkpoints}model.ptl')


"""Save .onnx format model"""
input_name = ['input']
output_name = ['output']
dummy_input = torch.randn(1, 3, opt.loadsize, opt.loadsize)     # torch.autograd.Variable is deprecated; a plain tensor suffices
torch.onnx.export(model, dummy_input, f'{opt.checkpoints}model.onnx', input_names=input_name, output_names=output_name, verbose=True)
onnx.save(onnx.shape_inference.infer_shapes(onnx.load(f'{opt.checkpoints}model.onnx')), f'{opt.checkpoints}model.onnx')   # Run shape inference
# simplified model
model_onnx = onnx.load(f'{opt.checkpoints}model.onnx')
model_simplified, check = simplify(model_onnx)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simplified, f'{opt.checkpoints}model_simplified.onnx')
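
After exporting, it is worth checking that the ONNX model numerically matches the PyTorch model; the following is a minimal sanity check (check_onnx_parity.py is a hypothetical helper, reusing the paths above):

# check_onnx_parity.py -- a minimal sanity check comparing PyTorch and ONNX outputs
import numpy as np
import onnxruntime
import torch
from model import CustomMobileNetV3

model = CustomMobileNetV3()
model.load_state_dict(torch.load('checkpoints/model_best.pth', map_location='cpu')['model'])
model.eval()

dummy = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    torch_out = model(dummy).numpy()

sess = onnxruntime.InferenceSession('checkpoints/model.onnx')
onnx_out = sess.run(['output'], {'input': dummy.numpy()})[0]
print('max abs diff:', np.abs(torch_out - onnx_out).max())   # typically ~1e-5 or smaller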

🐶 4.7 evaluate.py

The script defines three functions:

  • evaluate_image_single: predict a single image
  • evaluate_image_dir: predict a folder of images
  • evaluate_onnx_model: predict an image with the ONNX model

It also produces several visualizations and evaluation metrics, including the confusion matrix and the F1-score.

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torch.nn.functional as F
import torch.utils.data
import onnxruntime
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix, roc_curve, auc
from tqdm import tqdm
from getdata import mean, std, class_names
from option import get_args
opt = get_args()
device = 'cuda:1'   # Evaluation device; change to 'cuda:0' or 'cpu' as needed

"""Predicting a single image"""
def evaluate_image_single(img_path, transform_test, model, class_names, top_k):
    
    image = Image.open(img_path).convert('RGB')
    img = transform_test(image).to(device)
    img = img.unsqueeze_(0)
    out = model(img)
    pred_softmax = F.softmax(out, dim=1)
    top_n, top_n_indices = torch.topk(pred_softmax, top_k)
    
    confs = top_n[0].cpu().detach().numpy().tolist()
    class_names_top = [class_names[i] for i in top_n_indices[0]]
    
    for i in range(top_k):
        print(f'Pre: {class_names_top[i]}   Conf: {confs[i]:.3f}')
    
    confs_max = confs[0]
    
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.axis('off')
    plt.title(f'Pre: {class_names_top[0]}   Conf: {confs_max:.3f}')
    plt.imshow(image)
    
    sorted_pairs = sorted(zip(class_names_top, confs), key=lambda x: x[1], reverse=True)
    sorted_class_names_top, sorted_confs = zip(*sorted_pairs)
    
    plt.subplot(1, 2, 2)
    bars = plt.bar(sorted_class_names_top, sorted_confs, color='lightcoral')
    plt.xlabel('Class Names')
    plt.ylabel('Confidence')
    plt.title('Top 5 Predictions (Descending Order)')
    plt.xticks(rotation=45)
    plt.ylim(0, 1)
    plt.tight_layout()
    for bar, conf in zip(bars, sorted_confs):
        yval = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2, yval + 0.01, f'{conf:.3f}', ha='center', va='bottom')
    plt.savefig('predict_image_with_bars.jpg')


"""Predicting folder images"""
def evaluate_image_dir(model, dataloader, class_names):
    model.eval()
    all_preds = []
    all_labels = []
    correct_top1, correct_top3, correct_top5, total = torch.tensor(0.).to(device), torch.tensor(0.).to(device), torch.tensor(0.).to(device), torch.tensor(0.).to(device)
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            total += labels.size(0)
            correct_top1  += (outputs.argmax(1) == labels).type(torch.float).sum()
            _, predicted_top3 = torch.topk(outputs, 3, dim=1)
            _, predicted_top5 = torch.topk(outputs, 5, dim=1)
            correct_top3 += (predicted_top3[:, :3] == labels.unsqueeze(1).expand_as(predicted_top3)).sum().item()
            correct_top5 += (predicted_top5[:, :5] == labels.unsqueeze(1).expand_as(predicted_top5)).sum().item()
            
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().tolist())      # Move to CPU so the lists hold plain ints
            all_labels.extend(labels.cpu().tolist())
    
    all_preds = torch.tensor(all_preds)
    all_labels = torch.tensor(all_labels)
    
    top1 = correct_top1 / total
    top3 = correct_top3 / total
    top5 = correct_top5 / total
    print(f"Top-1 Accuracy: {top1:.4f}")
    print(f"Top-3 Accuracy: {top3:.4f}")
    print(f"Top-5 Accuracy: {top5:.4f}")
    
    accuracy = accuracy_score(all_labels.cpu().numpy(), all_preds.cpu().numpy())
    precision = precision_score(all_labels.cpu().numpy(), all_preds.cpu().numpy(), average='macro')
    recall = recall_score(all_labels.cpu().numpy(), all_preds.cpu().numpy(), average='macro')
    f1 = f1_score(all_labels.cpu().numpy(), all_preds.cpu().numpy(), average='macro')
    
    cm = confusion_matrix(all_labels.cpu().numpy(), all_preds.cpu().numpy())
    report = classification_report(all_labels.cpu().numpy(), all_preds.cpu().numpy(), target_names=class_names)
    
    print(f'Accuracy: {accuracy:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'F1 Score: {f1:.4f}')
    print(report)

    plt.figure(figsize=(100, 100))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names, annot_kws={"size": 8})
    plt.xticks(rotation=90) 
    plt.yticks(rotation=0)  
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix')
    plt.savefig('confusion_matrix.jpg')

"""Using .onnx model to predict images"""
def evaluate_onnx_model(img_path, data_transform, onnx_model_path, class_names, top_k=5):
    ort_session = onnxruntime.InferenceSession(onnx_model_path)
    img_pil = Image.open(img_path).convert('RGB')
    input_img = data_transform(img_pil)
    input_tensor = input_img.unsqueeze(0).numpy()
    ort_inputs = {'input': input_tensor}
    out = ort_session.run(['output'], ort_inputs)[0]
    
    def softmax(x):
        x = x - np.max(x, axis=1, keepdims=True)    # Subtract the row max for numerical stability
        return np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)
    
    prob_dist = softmax(out)
    result_dict = {label: float(prob_dist[0][i]) for i, label in enumerate(class_names)}
    result_dict = dict(sorted(result_dict.items(), key=lambda item: item[1], reverse=True))

    for key, value in list(result_dict.items())[:top_k]:
        print(f'Pre: {key}   Conf: {value:.3f}')

    confs_max = list(result_dict.values())[0]
    class_names_top = list(result_dict.keys())

    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.axis('off')
    plt.title(f'Pre: {class_names_top[0]}   Conf: {confs_max:.3f}')
    plt.imshow(img_pil)

    plt.subplot(1, 2, 2)
    bars = plt.bar(class_names_top[:top_k], list(result_dict.values())[:top_k], color='lightcoral')
    plt.xlabel('Class Names')
    plt.ylabel('Confidence')
    plt.title('Top 5 Predictions (Descending Order)')
    plt.xticks(rotation=45)
    plt.ylim(0, 1)
    plt.tight_layout()
    for bar, conf in zip(bars, list(result_dict.values())[:top_k]):
        yval = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2, yval + 0.01, f'{conf:.3f}', ha='center', va='bottom')
    plt.savefig('predict_image_with_bars.jpg')



if __name__ == '__main__':
    data_transform = transforms.Compose([transforms.Resize((opt.loadsize, opt.loadsize)), transforms.ToTensor(),transforms.Normalize(mean, std)])
    image_datasets = ImageFolder(opt.dataset_test, data_transform)
    dataloaders = DataLoader(image_datasets, batch_size=512, shuffle=True)
    
    ptl_model_path = opt.checkpoints + 'model.ptl'
    pth_model_path = opt.checkpoints + 'model.pth'
    onnx_model_path = opt.checkpoints + 'model.onnx'
    
    ptl_model = torch.jit.load(ptl_model_path).to(device)
    pth_model = torch.load(pth_model_path).to(device)
    
    evaluate_image_single(opt.test_img_path, data_transform, pth_model, class_names, top_k=5)     # Predicting a single image
    # evaluate_image_dir(pth_model, dataloaders, class_names)     # Predicting folder images
    # evaluate_onnx_model(opt.test_img_path, data_transform, onnx_model_path, class_names, top_k=5)   # Predicting a single image

Using the evaluate_image_single function to predict datasets/test/zigzag/zigzag-4508464694951936.png produces a top-5 prediction chart, saved as predict_image_with_bars.jpg.

Using the evaluate_image_dir function on the images under the datasets/test path gives the following results:

Top-1 Accuracy: 0.6833
Top-3 Accuracy: 0.8521
Top-5 Accuracy: 0.8933
Accuracy: 0.6833
Precision: 0.6875
Recall: 0.6833
F1 Score: 0.6817
                         precision    recall  f1-score   support

       The Eiffel Tower       0.83      0.88      0.85      1000
The Great Wall of China       0.47      0.36      0.41      1000
          The Mona Lisa       0.68      0.86      0.76      1000
               airplane       0.83      0.74      0.78      1000
            alarm clock       0.76      0.76      0.76      1000
              ambulance       0.70      0.65      0.67      1000
                  angel       0.87      0.78      0.82      1000
       animal migration       0.47      0.66      0.55      1000
                    ant       0.77      0.74      0.75      1000
                  anvil       0.80      0.66      0.72      1000
                  apple       0.82      0.85      0.83      1000
                    arm       0.74      0.69      0.71      1000
              asparagus       0.54      0.44      0.48      1000
                    axe       0.69      0.67      0.68      1000
               backpack       0.61      0.75      0.67      1000
                 banana       0.68      0.72      0.70      1000
                bandage       0.83      0.71      0.77      1000
                   barn       0.66      0.68      0.67      1000
               baseball       0.77      0.71      0.74      1000
           baseball bat       0.75      0.73      0.74      1000
                 basket       0.71      0.62      0.66      1000
             basketball       0.62      0.72      0.66      1000
                    bat       0.79      0.62      0.69      1000
                bathtub       0.60      0.64      0.62      1000
                  beach       0.58      0.65      0.61      1000
                   bear       0.46      0.31      0.37      1000
                  beard       0.56      0.73      0.63      1000
                    bed       0.80      0.67      0.73      1000
                    bee       0.82      0.74      0.78      1000
                   belt       0.78      0.55      0.64      1000
                  bench       0.59      0.53      0.56      1000
                bicycle       0.73      0.72      0.72      1000
             binoculars       0.74      0.77      0.76      1000
                   bird       0.47      0.43      0.45      1000
          birthday cake       0.52      0.64      0.57      1000
             blackberry       0.46      0.42      0.44      1000
              blueberry       0.58      0.47      0.52      1000
                   book       0.72      0.78      0.75      1000
              boomerang       0.73      0.70      0.71      1000
              bottlecap       0.58      0.54      0.56      1000
                 bowtie       0.87      0.86      0.86      1000
               bracelet       0.68      0.60      0.64      1000
                  brain       0.59      0.60      0.59      1000
                  bread       0.54      0.63      0.58      1000
                 bridge       0.61      0.64      0.63      1000
               broccoli       0.58      0.70      0.64      1000
                  broom       0.56      0.68      0.61      1000
                 bucket       0.62      0.66      0.64      1000
              bulldozer       0.69      0.70      0.70      1000
                    bus       0.56      0.42      0.48      1000
                   bush       0.47      0.65      0.55      1000
              butterfly       0.86      0.88      0.87      1000
                 cactus       0.69      0.87      0.77      1000
                   cake       0.53      0.42      0.47      1000
             calculator       0.76      0.82      0.79      1000
               calendar       0.54      0.50      0.52      1000
                  camel       0.82      0.84      0.83      1000
                 camera       0.87      0.74      0.80      1000
             camouflage       0.23      0.43      0.30      1000
               campfire       0.72      0.77      0.75      1000
                 candle       0.75      0.73      0.74      1000
                 cannon       0.77      0.69      0.72      1000
                  canoe       0.67      0.63      0.65      1000
                    car       0.65      0.63      0.64      1000
                 carrot       0.75      0.82      0.78      1000
                 castle       0.79      0.72      0.75      1000
                    cat       0.69      0.66      0.68      1000
            ceiling fan       0.83      0.64      0.72      1000
             cell phone       0.62      0.60      0.61      1000
                  cello       0.51      0.67      0.58      1000
                  chair       0.83      0.80      0.81      1000
             chandelier       0.74      0.71      0.73      1000
                 church       0.72      0.67      0.69      1000
                 circle       0.53      0.86      0.66      1000
               clarinet       0.53      0.63      0.58      1000
                  clock       0.86      0.77      0.82      1000
                  cloud       0.73      0.69      0.71      1000
             coffee cup       0.67      0.43      0.52      1000
                compass       0.69      0.78      0.73      1000
               computer       0.79      0.62      0.69      1000
                 cookie       0.68      0.80      0.74      1000
                 cooler       0.47      0.33      0.38      1000
                  couch       0.76      0.82      0.79      1000
                    cow       0.70      0.57      0.63      1000
                   crab       0.70      0.72      0.71      1000
                 crayon       0.44      0.52      0.47      1000
              crocodile       0.65      0.57      0.60      1000
                  crown       0.87      0.87      0.87      1000
            cruise ship       0.76      0.69      0.73      1000
                    cup       0.43      0.50      0.47      1000
                diamond       0.73      0.88      0.80      1000
             dishwasher       0.56      0.47      0.51      1000
           diving board       0.53      0.54      0.53      1000
                    dog       0.50      0.41      0.45      1000
                dolphin       0.79      0.59      0.68      1000
                  donut       0.75      0.88      0.81      1000
                   door       0.69      0.72      0.70      1000
                 dragon       0.52      0.42      0.47      1000
                dresser       0.75      0.65      0.70      1000
                  drill       0.78      0.71      0.75      1000
                  drums       0.71      0.68      0.70      1000
                   duck       0.68      0.49      0.57      1000
               dumbbell       0.78      0.80      0.79      1000
                    ear       0.81      0.75      0.78      1000
                  elbow       0.74      0.62      0.68      1000
               elephant       0.66      0.66      0.66      1000
               envelope       0.87      0.94      0.90      1000
                 eraser       0.50      0.61      0.55      1000
                    eye       0.83      0.85      0.84      1000
             eyeglasses       0.84      0.80      0.82      1000
                   face       0.62      0.64      0.63      1000
                    fan       0.76      0.60      0.67      1000
                feather       0.58      0.60      0.59      1000
                  fence       0.67      0.71      0.69      1000
                 finger       0.70      0.63      0.67      1000
           fire hydrant       0.56      0.64      0.60      1000
              fireplace       0.74      0.67      0.71      1000
              firetruck       0.71      0.50      0.59      1000
                   fish       0.89      0.85      0.87      1000
               flamingo       0.69      0.75      0.72      1000
             flashlight       0.80      0.82      0.81      1000
             flip flops       0.64      0.75      0.69      1000
             floor lamp       0.77      0.70      0.74      1000
                 flower       0.79      0.83      0.81      1000
          flying saucer       0.65      0.64      0.64      1000
                   foot       0.68      0.66      0.67      1000
                   fork       0.81      0.79      0.80      1000
                   frog       0.46      0.47      0.47      1000
             frying pan       0.78      0.76      0.77      1000
                 garden       0.59      0.63      0.61      1000
            garden hose       0.42      0.28      0.33      1000
                giraffe       0.87      0.80      0.84      1000
                 goatee       0.72      0.73      0.72      1000
              golf club       0.60      0.62      0.61      1000
                 grapes       0.68      0.65      0.66      1000
                  grass       0.59      0.83      0.69      1000
                 guitar       0.68      0.50      0.58      1000
              hamburger       0.66      0.83      0.73      1000
                 hammer       0.71      0.75      0.73      1000
                   hand       0.83      0.83      0.83      1000
                   harp       0.83      0.78      0.80      1000
                    hat       0.72      0.71      0.72      1000
             headphones       0.92      0.91      0.92      1000
               hedgehog       0.73      0.74      0.73      1000
             helicopter       0.81      0.83      0.82      1000
                 helmet       0.63      0.66      0.64      1000
                hexagon       0.70      0.73      0.72      1000
            hockey puck       0.59      0.61      0.60      1000
           hockey stick       0.59      0.54      0.56      1000
                  horse       0.53      0.85      0.65      1000
               hospital       0.80      0.68      0.74      1000
        hot air balloon       0.79      0.72      0.75      1000
                hot dog       0.60      0.63      0.62      1000
                hot tub       0.58      0.51      0.54      1000
              hourglass       0.86      0.87      0.87      1000
                  house       0.77      0.77      0.77      1000
            house plant       0.85      0.82      0.83      1000
              hurricane       0.39      0.45      0.42      1000
              ice cream       0.82      0.85      0.84      1000
                 jacket       0.75      0.72      0.74      1000
                   jail       0.71      0.72      0.71      1000
               kangaroo       0.73      0.71      0.72      1000
                    key       0.71      0.76      0.74      1000
               keyboard       0.50      0.48      0.49      1000
                   knee       0.63      0.68      0.65      1000
                 ladder       0.88      0.91      0.89      1000
                lantern       0.70      0.53      0.60      1000
                 laptop       0.63      0.80      0.71      1000
                   leaf       0.73      0.71      0.72      1000
                    leg       0.58      0.50      0.54      1000
             light bulb       0.69      0.79      0.73      1000
             lighthouse       0.71      0.74      0.72      1000
              lightning       0.76      0.69      0.72      1000
                   line       0.55      0.82      0.66      1000
                   lion       0.70      0.76      0.73      1000
               lipstick       0.59      0.69      0.63      1000
                lobster       0.61      0.47      0.53      1000
               lollipop       0.76      0.85      0.80      1000
                mailbox       0.75      0.66      0.70      1000
                    map       0.65      0.73      0.68      1000
                 marker       0.39      0.16      0.23      1000
                matches       0.52      0.47      0.49      1000
              megaphone       0.80      0.70      0.75      1000
                mermaid       0.76      0.84      0.80      1000
             microphone       0.64      0.73      0.68      1000
              microwave       0.79      0.75      0.77      1000
                 monkey       0.59      0.56      0.57      1000
                   moon       0.69      0.60      0.64      1000
               mosquito       0.48      0.55      0.51      1000
              motorbike       0.64      0.62      0.63      1000
               mountain       0.74      0.80      0.77      1000
                  mouse       0.53      0.46      0.49      1000
              moustache       0.75      0.72      0.73      1000
                  mouth       0.72      0.76      0.74      1000
                    mug       0.54      0.65      0.59      1000
               mushroom       0.66      0.76      0.70      1000
                   nail       0.58      0.66      0.62      1000
               necklace       0.75      0.63      0.68      1000
                   nose       0.69      0.75      0.72      1000
                  ocean       0.54      0.54      0.54      1000
                octagon       0.71      0.62      0.66      1000
                octopus       0.89      0.83      0.86      1000
                  onion       0.75      0.68      0.71      1000
                   oven       0.50      0.39      0.44      1000
                    owl       0.68      0.65      0.67      1000
              paint can       0.51      0.49      0.50      1000
             paintbrush       0.58      0.63      0.61      1000
              palm tree       0.73      0.83      0.78      1000
                  panda       0.66      0.62      0.64      1000
                  pants       0.75      0.68      0.71      1000
             paper clip       0.75      0.78      0.76      1000
              parachute       0.81      0.79      0.80      1000
                 parrot       0.54      0.59      0.56      1000
               passport       0.60      0.55      0.58      1000
                 peanut       0.70      0.73      0.71      1000
                   pear       0.72      0.80      0.76      1000
                   peas       0.70      0.56      0.62      1000
                 pencil       0.58      0.60      0.59      1000
                penguin       0.69      0.78      0.73      1000
                  piano       0.65      0.66      0.65      1000
           pickup truck       0.60      0.64      0.62      1000
          picture frame       0.68      0.89      0.77      1000
                    pig       0.77      0.56      0.65      1000
                 pillow       0.60      0.58      0.59      1000
              pineapple       0.80      0.85      0.82      1000
                  pizza       0.65      0.77      0.70      1000
                 pliers       0.69      0.55      0.61      1000
             police car       0.67      0.68      0.67      1000
                   pond       0.40      0.47      0.43      1000
                   pool       0.51      0.23      0.32      1000
               popsicle       0.70      0.79      0.75      1000
               postcard       0.74      0.58      0.65      1000
                 potato       0.54      0.40      0.46      1000
           power outlet       0.61      0.72      0.66      1000
                  purse       0.64      0.69      0.66      1000
                 rabbit       0.66      0.80      0.72      1000
                raccoon       0.43      0.44      0.44      1000
                  radio       0.71      0.59      0.64      1000
                   rain       0.77      0.90      0.83      1000
                rainbow       0.79      0.92      0.85      1000
                   rake       0.69      0.67      0.68      1000
         remote control       0.67      0.68      0.67      1000
             rhinoceros       0.65      0.75      0.69      1000
                  river       0.66      0.61      0.64      1000
         roller coaster       0.70      0.52      0.60      1000
           rollerskates       0.86      0.83      0.84      1000
               sailboat       0.84      0.87      0.86      1000
               sandwich       0.50      0.68      0.57      1000
                    saw       0.81      0.83      0.82      1000
              saxophone       0.79      0.77      0.78      1000
             school bus       0.51      0.44      0.47      1000
               scissors       0.80      0.84      0.82      1000
               scorpion       0.70      0.76      0.73      1000
            screwdriver       0.58      0.62      0.60      1000
             sea turtle       0.79      0.73      0.76      1000
                see saw       0.85      0.79      0.82      1000
                  shark       0.72      0.72      0.72      1000
                  sheep       0.75      0.80      0.77      1000
                   shoe       0.73      0.75      0.74      1000
                 shorts       0.67      0.76      0.71      1000
                 shovel       0.62      0.73      0.67      1000
                   sink       0.62      0.76      0.68      1000
             skateboard       0.83      0.85      0.84      1000
                  skull       0.86      0.83      0.85      1000
             skyscraper       0.65      0.56      0.60      1000
           sleeping bag       0.55      0.59      0.57      1000
            smiley face       0.74      0.80      0.77      1000
                  snail       0.79      0.90      0.84      1000
                  snake       0.65      0.66      0.65      1000
                snorkel       0.79      0.73      0.76      1000
              snowflake       0.79      0.84      0.81      1000
                snowman       0.83      0.90      0.86      1000
            soccer ball       0.69      0.70      0.69      1000
                   sock       0.77      0.75      0.76      1000
              speedboat       0.65      0.65      0.65      1000
                 spider       0.72      0.79      0.76      1000
                  spoon       0.69      0.57      0.63      1000
            spreadsheet       0.67      0.62      0.65      1000
                 square       0.52      0.84      0.65      1000
               squiggle       0.41      0.40      0.40      1000
               squirrel       0.71      0.74      0.72      1000
                 stairs       0.90      0.91      0.90      1000
                   star       0.93      0.91      0.92      1000
                  steak       0.53      0.46      0.49      1000
                 stereo       0.61      0.68      0.64      1000
            stethoscope       0.87      0.75      0.81      1000
               stitches       0.71      0.79      0.75      1000
              stop sign       0.86      0.88      0.87      1000
                  stove       0.71      0.66      0.69      1000
             strawberry       0.80      0.80      0.80      1000
            streetlight       0.75      0.71      0.73      1000
            string bean       0.51      0.39      0.44      1000
              submarine       0.83      0.67      0.74      1000
               suitcase       0.75      0.57      0.64      1000
                    sun       0.87      0.88      0.87      1000
                   swan       0.69      0.67      0.68      1000
                sweater       0.68      0.65      0.67      1000
              swing set       0.89      0.90      0.89      1000
                  sword       0.85      0.81      0.83      1000
                t-shirt       0.80      0.78      0.79      1000
                  table       0.73      0.76      0.74      1000
                 teapot       0.82      0.77      0.80      1000
             teddy-bear       0.66      0.74      0.70      1000
              telephone       0.67      0.54      0.60      1000
             television       0.88      0.85      0.86      1000
         tennis racquet       0.86      0.74      0.80      1000
                   tent       0.80      0.77      0.78      1000
                  tiger       0.53      0.47      0.50      1000
                toaster       0.59      0.70      0.64      1000
                    toe       0.67      0.63      0.65      1000
                 toilet       0.74      0.80      0.77      1000
                  tooth       0.72      0.74      0.73      1000
             toothbrush       0.74      0.76      0.75      1000
             toothpaste       0.54      0.56      0.55      1000
                tornado       0.63      0.69      0.66      1000
                tractor       0.65      0.71      0.68      1000
          traffic light       0.84      0.84      0.84      1000
                  train       0.61      0.74      0.67      1000
                   tree       0.72      0.75      0.73      1000
               triangle       0.87      0.93      0.90      1000
               trombone       0.58      0.48      0.53      1000
                  truck       0.50      0.41      0.45      1000
                trumpet       0.65      0.49      0.56      1000
               umbrella       0.91      0.86      0.88      1000
              underwear       0.83      0.64      0.72      1000
                    van       0.46      0.58      0.51      1000
                   vase       0.82      0.67      0.74      1000
                 violin       0.52      0.52      0.52      1000
        washing machine       0.74      0.78      0.76      1000
             watermelon       0.56      0.66      0.61      1000
             waterslide       0.57      0.70      0.63      1000
                  whale       0.71      0.74      0.72      1000
                  wheel       0.82      0.50      0.62      1000
               windmill       0.82      0.77      0.79      1000
            wine bottle       0.77      0.81      0.79      1000
             wine glass       0.86      0.85      0.86      1000
             wristwatch       0.72      0.74      0.73      1000
                   yoga       0.60      0.57      0.58      1000
                  zebra       0.73      0.66      0.69      1000
                 zigzag       0.73      0.75      0.74      1000

               accuracy                           0.68    340000
              macro avg       0.69      0.68      0.68    340000
           weighted avg       0.69      0.68      0.68    340000

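The table above is a per-class precision/recall/F1 breakdown in the format produced by scikit-learn's `classification_report`; overall test accuracy comes out to 0.68 over the 340,000 evaluation images (1,000 per class). Below is a minimal sketch of how such a report can be generated from a trained model. Note that `model`, `test_loader`, and `class_names` are hypothetical placeholders standing in for the objects built in evaluate.py, not code taken from the repository.

```python
# A minimal sketch (not the repository's evaluate.py) of producing a
# per-class report like the one above with sklearn's classification_report.
# `model`, `test_loader`, and `class_names` are hypothetical placeholders.
import torch
from sklearn.metrics import classification_report

@torch.no_grad()
def print_class_report(model, test_loader, class_names, device="cuda"):
    model.eval()
    y_true, y_pred = [], []
    for images, labels in test_loader:
        logits = model(images.to(device))              # forward pass
        y_pred.extend(logits.argmax(dim=1).cpu().tolist())
        y_true.extend(labels.tolist())
    # digits=2 matches the two-decimal formatting shown above
    print(classification_report(y_true, y_pred,
                                target_names=class_names, digits=2))
```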