基于PyTorch的视频分类实战

news2024/11/16 20:29:04

1、数据集下载

官方链接:https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/#Downloads

百度网盘连接:

https://pan.baidu.com/s/1sSn--u_oLvTDjH-BgOAv_Q?pwd=xsri

提取码: xsri 

        官方链接有详细的数据集介绍,下载的是压缩包 ‘hmdb51_org.rar’,解压后里面是 51 个.rar 压缩包,每个压缩包名是一个类别,里面的是对应类别的视频片段(.avi 文件)。因为资源有限,这里只解压了 5 个类别的视频如图 1 所示:

图1 'hmdb5/org'

        这里新建了 ‘hmdb5’ 文件夹,并新建了 ‘org’ 子文件夹,然后把 ‘hmdb51_org’ 文件夹的 5 个子文件夹放到 ‘org’ 中。作为这次实践的源视频数据。

2、utils.py
        在这里先实现 utils.py,即取帧(get_frames)和存帧(store_frames)函数,取帧函数的功能为从视频中等间距抽取 n_frame 帧,并返回这些帧组成的列表。存帧函数的功能即为将帧列表按序存到 path 中。

import os
import cv2
import numpy as np


def get_frames(path, n_frames=1):
    """
    :param path: 视频文件路径
    :param n_frames: 读取的帧数
    :return: 读取的帧列表 frames
    """
    frames = []

    # 实例化一个用于捕获视频流的对象, 若参数为整数则用于读取摄像头视频, 若参数为字符串则用于读取视频文件
    v_cap = cv2.VideoCapture(path)

    '''
    cv2.CAP_PROP_FRAME_COUNT 是 cv2.VideoCapture 类的一个属性标识符,用于获取视频流或视频文件中的总帧数
    cv2.VideoCapture 的 get 方法用于获取视频流或视频文件的属性(返回值均为实数):
        propId 是属性标识符,整数:
            cv2.CAP_PROP_FRAME_WIDTH:视频的帧宽度(以像素为单位)
            cv2.CAP_PROP_FRAME_HEIGHT:视频的帧高度(以像素为单位)
            cv2.CAP_PROP_FPS:视频的帧率(每秒的帧数)
            cv2.CAP_PROP_POS_FRAMES:当前读取帧的位置(基于 0 的索引)
            cv2.CAP_PROP_POS_AVI_RATIO:视频文件的相对位置(播放进度)
            cv2.CAP_PROP_FRAME_COUNT:视频文件中的总帧数
    '''
    v_len = int(v_cap.get(propId=cv2.CAP_PROP_FRAME_COUNT))

    '''
    在指定区间返回等距的数字数组:
        start: 区间起点
        stop: 区间终点
        num: 采样数量
        endpoint: 默认为 True,若为 False 则区间不包括 stop
    '''
    frame_list = np.linspace(start=0, stop=v_len - 1, num=n_frames + 1, dtype=np.int16)

    for fn in range(v_len):
        # 读取下一帧。它返回两个值:一个布尔值 success 表示是否成功读取帧和一个数组 frame 表示读取到的帧。
        success, frame = v_cap.read()
        if success is False:
            continue
        if fn in frame_list:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame)
    v_cap.release()
    return frames


def store_frames(frames, path):
    """
    :param frames: 待保存为 jpg 图片的帧列表
    :param path: 存储路径
    :return:
    """
    for i, frame in enumerate(frames):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        path2img = os.path.join(path, "frame" + str(i) + ".jpg")
        cv2.imwrite(path2img, frame)

3、数据抽帧,并划分训练集和测试集

        先在 ‘hmdb5’ 文件夹中新建子文件夹 ‘train’ 和 ‘test',再运行以下代码即可数据抽帧,并划分训练集和测试集。       

import os
from utils import get_frames, store_frames

path = "hmdb5"
org_dir = "org"
org_path = os.path.join(path, org_dir)
categories_list = os.listdir(org_path)
# brush_hair: 0, chartwheel: 1, clap: 2, catch: 3, chew: 4

# 输出每个类别的视频数量
for c in categories_list:
	print("category:", c)
	p = os.path.join(org_path, c)
	video_list = os.listdir(p)
	print("number of videos:", len(video_list))
	print("-" * 50)
"""
category: brush_hair
number of videos: 107
--------------------------------------------------
category: cartwheel
number of videos: 107
--------------------------------------------------
category: clap
number of videos: 130
--------------------------------------------------
category: catch
number of videos: 102
--------------------------------------------------
category: chew
number of videos: 109
--------------------------------------------------
"""



extension = '.avi'
n_frames = 16
train_rate = 0.9

for i, c in enumerate(categories_list):
	p = os.path.join(org_path, c)
	videos = [v for v in os.listdir(p) if v.endswith(extension)]
	train_size = int(len(videos) * train_rate)
	for j, name in enumerate(videos):
		video_path = os.path.join(p, name)
		frames = get_frames(video_path, n_frames=n_frames)
		path2store = os.path.join(path, "train")
		if j >= train_size:
			path2store = os.path.join(path, "test")
		path2store = os.path.join(path2store, str(i)+"_"+name[:-4])
		print(path2store)
		os.makedirs(path2store, exist_ok=True)
		store_frames(frames, path2store)

        第一段代码输出五个类别的视频数量,可以看到 brush_hair、cartwheel、clap、catch 和  chew 依次有 107、107、130、102、109 个视频。最后一段代码的功能是依次对每个类别的每个视频抽帧,并将抽帧结果存置指定路径,同时划分训练集和测试集。这里设置每个视频的抽帧数量 n_frame=16,按 9:1(498:57) 划分训练集和测试集。每个样本(视频文件夹)名都在原来的名字前拼接上 ‘类别编号_’,其中类别编号为:

brush_hair: 0, chartwheel: 1, clap: 2, catch: 3, chew: 4

        这段代码的运行结果如图 2 所示(以测试集为例)即所有类别样本都在一个文件夹中,不再有类别目录,样本名字最前面的数字即为该样本的类别。

图2 测试集部分样本

4、train.py

4.1 导包

import os
import re
import torch
from torch import nn
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from torchvision.models import video
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

4.2 设置环境变量

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

4.3 定义超参数

lr = 3e-5
gamma = 0.5
epochs = 20
step_size = 5
batch_size = 16
weight_decay = 1e-2

        这里定义初始学习率为 lr=3e-5,训练轮次为 epochs=20,batch_size=16,正则化系数为 weight_decay=1e-2。gamma 和 step_size 分为 torch.optim.lr_scheduler.ReduceLROnPlateau 类构造函数的入参 factor 和 patience。factor 是学习率降低的因子,新的学习率将是当前学习率乘以这个因子;patience 指观察验证指标在多少个 epoch 内没有改善后降低学习率。

4.4 定义图像变换函数

train_transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.RandomHorizontalFlip(p=0.5),
    # 用于对图像进行随机的仿射变换, degrees 为旋转角度, translate 为水平和垂直平移的最大绝对分数
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize([0.4322, 0.3947, 0.3765], [0.2280, 0.2215, 0.2170])
])

test_transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ToTensor(),
    transforms.Normalize([0.4322, 0.3947, 0.3765], [0.2280, 0.2215, 0.2170]),
])

        这里定义了两个图像变换函数,即用于训练集的 train_transform 和用于测试集的 test_transform, train_transform 在训练前依次对图片进行 resize 操作,以 0.5 的概率水平镜像变换操作,随机仿射操作(随机沿 x,y 方向分别平移 (-0.1*w,0.1*w)、(-0.1*h,0.1*h)),转换为 tensor 操作和标准化操作。test_transform 相较于 train_transform 去掉了起数据增强作用的两个操作。

4.5 定义训练集和测试集路径

# 训练集(498):测试机(57)=9:1
train_dir = 'hmdb5/train'
test_dir = 'hmdb5/test'

4.6 定义数据集类

class HMDB5Dataset(Dataset):
    def __init__(self, directory, transform):
        self.dir = directory
        self.transform = transform
        self.names = os.listdir(directory)

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        path = os.path.join(self.dir, self.names[idx])
        frames = []
        for i in range(16):
            frame = Image.open(os.path.join(path, 'frame' + str(i) + '.jpg'))
            frames.append(self.transform(frame))

        frames = torch.stack(frames)
        # 返回 input 的转置版本, 即交换 input 的 dim0 和 dim1
        frames = torch.transpose(input=frames, dim0=0, dim1=1)
        # 编译正则表达式, ^ 表示匹配字符串的开始, + 表示一个或多个
        pattern = re.compile(r'^(\d+)_')
        match = re.search(pattern, self.names[idx])
        return frames, int(match.group(1))

        数据集类的构造函数定义了 3 个属性:dir(数据集路径)、transform(数据预处理方式)和names(样本名列表)。

        __getitem__ 函数根据 idx 按序取出一个样本的所有帧,并对所有帧执行了 transform 操作,最后返回的样本 frames 是 shape 为(channels,n_frames,h,w)的 tensor,该函数还利用 re 库从样本名中获取该样本的标签并返回。

4.7 定义模型

def init_model(mi):
    m = None
    if mi == 1:
        m = video.r3d_18(num_classes=5)  # epochs = 20, correct = 0.754

    return m.to(device)

        这里使用的模型为  torchvision.models.video.r3d_18[1],原文链接:https://arxiv.org/abs/1711.11248。实现可以参考 torch 源码。

4.8 计算评价指标

def correct_loss(data_loader, desc, test):
    results = []
    correct = 0.0
    test_loss = 0.0

    for img, tag in tqdm(data_loader, desc, total=len(data_loader)):
        img = img.to(device)
        tag = tag.to(device)
        pre = model(img)
        if test:
            test_loss += loss_fn(pre, tag)
        correct += torch.sum((pre.argmax(dim=1) == tag).float())

    results.append(correct / len(data_loader.dataset))
    if test:
        results.append(test_loss)

    return results

        correct_loss 函数用于计算 model 在 data_loader 上的 correct 和 loss(如果 test=True ,即data_loader 是测试集的数据加载器)。并将结果以列表的形式返回。

4.9 训练

if __name__ == '__main__':
    model = init_model(1)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr, weight_decay=weight_decay)

    train_ds = HMDB5Dataset(train_dir, train_transform)
    test_ds = HMDB5Dataset(test_dir, test_transform)
    train_dl = DataLoader(train_ds, batch_size, True, num_workers=2)
    test_dl = DataLoader(test_ds, batch_size, False, num_workers=2)
    
    '''
     在验证指标停止改善时降低学习率:
         mode(str): 值域为 {'min', 'max'}。指定优化器应该监视的指标是应该最小化还是最大化
         factor(float): 学习率降低的因子。新的学习率将是当前学习率乘以这个因子
         patience(int): 观察验证指标在多少个 epoch 内没有改善后降低学习率
    '''
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=gamma, patience=step_size, verbose=True)
    best_loss = float('inf')
    for epoch in range(epochs):
        s_loss = 0.0
        print('Epoch:', epoch + 1, '/', epochs)
        for x, y in tqdm(train_dl, total=len(train_dl)):
            x = x.to(device)
            y = y.to(device)
            pred = model(x)
            loss = loss_fn(pred, y)
            s_loss += loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        model.eval()  # 将模型设置为评估模式
        with torch.no_grad():
            print("s_loss:%.3f" % s_loss)
            train_metrics = correct_loss(train_dl, 'compute train_metrics:', False)
            test_metrics = correct_loss(test_dl, 'compute test_metrics:', True)
            if test_metrics[1] < best_loss:
                best_loss = test_metrics[1]
            print("train_correct:%.3f,test_correct:%.3f" % (train_metrics[0], test_metrics[0]))

        model.train()
        scheduler.step(best_loss)

        这里使用交叉墒损失函数,AdamW 优化器,学习率使用 ReduceLROnPlateau scheduler,该 scheduler 监视的指标为 test loss。训练过程中得到的最高 test_correct=0.754。

5、项目目录结构

参考文献

[1] Du Tran, Heng Wang, Lorenzo Torresani, Jamie Ray, Yann LeCun, and Manohar Paluri. A closer look at spatiotemporal convolutions for action recognition. In CVPR, pages 6450–6459, 2018. 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1526255.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

疯狂 META:Aavegotchi 新一季稀有度挖矿来了!

经过数周的激烈讨论和参与&#xff0c;AavegotchiDAO 再次投票决定资助新一季的稀有度挖矿活动&#xff0c;这也是我们神奇的第八季&#xff01;朋友们&#xff0c;我们又开始啦——拿出你们最好的装备&#xff0c;擦亮那些可穿戴设备&#xff0c;准备好赚钱吧&#xff01; 与…

HarmonyOS系统开发ArkTS入门案例及组件

目录 一、声明式UI 二、ArkTs 快速入门案例 三、组件 四、渲染控制 一、声明式UI 声明式UI就是一种编写用户界面的范式或方式、 ArArkTS 在继承了Typescript语法的基础上&#xff0c;主要扩展了声明式UI开发相关的能力。 声明式UI开发范式大致流程&#xff1a;定义页面…

IPv4到IPv6的过渡策略

IPv4到IPv6的过渡是一个复杂且必要的过程&#xff0c;随着全球互联网的不断发展&#xff0c;IPv4地址资源的枯竭使得向IPv6过渡成为一项紧迫的任务。IPv6提供了更广阔的地址空间、更高的安全性和更灵活的路由方式&#xff0c;是未来互联网发展的必然趋势。下面将详细阐述如何从…

面向对象【内部类】

什么是内部类 将一个类 A 定义在另一个类 B 里面&#xff0c;里面的那个类 A 就称为内部类&#xff08;InnerClass&#xff09;&#xff0c;类 B 则称为外部类&#xff08;OuterClass&#xff09; 为什么要声明内部类 具体来说&#xff0c;当一个事物 A 的内部&#xff0c;还…

基于SpringBoot的后勤管理系统【附源码】

后勤管理系统开发说明 开发语言&#xff1a;Java 框架&#xff1a;ssm JDK版本&#xff1a;JDK1.8 服务器&#xff1a;tomcat7 数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09; 数据库工具&#xff1a;Navicat11 开发软件&#xff1a;eclipse/myecli…

LLM 面试知识点——模型基础知识

1、主流架构 目前LLM(Large Language Model)主流结构包括三种范式,分别为Encoder-Decoder、Causal Decoder、Prefix Decode。对应的网络整体结构和Attention掩码如下图。 、 各自特点、优缺点如下: 1)Encoder-Decoder 结构特点:输入双向注意力,输出单向注意力。 代表…

【C语言】linux内核pci_save_state

一、中文注释 //include\linux\pci.h /* 电源管理相关的例程 */ int pci_save_state(struct pci_dev *dev);//drivers\pci\pci.c /*** pci_save_state - 在挂起前保存PCI设备的配置空间* dev: - 我们正在处理的PCI设备*/ int pci_save_state(struct pci_dev *dev) {int i;/* X…

HTML + CSS 核心知识点- 定位

简述&#xff1a; 补充固定定位也会脱离文档流、不会占据原先位置 1、什么是文档流 文档流是指HTML文档中元素排列的规律和顺序。在网页中&#xff0c;元素按照其在HTML文档中出现的顺序依次排列&#xff0c;这种排列方式被称为文档流。文档流决定了元素在页面上的位置和互相之…

基于Spring Boot的美食分享系统设计与实现

摘 要 美食分享管理&#xff0c;其工作流程繁杂、多样、管理复杂与设备维护繁琐。而计算机已完全能够胜任美食分享管理工作&#xff0c;而且更加准确、方便、快捷、高效、清晰、透明&#xff0c;它完全可以克服以上所述的不足之处。这将给查询信息和管理带来很大的方便&#x…

PHP<=7.4.21 Development Server源码泄露漏洞 例题

打开题目 dirsearch扫描发现存在shell.php 非预期解 访问shell.php&#xff0c;往下翻直接就看到了flag.. 正常解法 访问shell.php 看见php的版本是7.3.33 我们知道 PHP<7.4.21时通过php -S开起的WEB服务器存在源码泄露漏洞&#xff0c;可以将PHP文件作为静态文件直接输…

万界星空科技WMS仓储管理包含哪些具体内容?

wms仓库管理是通过入库业务、出库业务、仓库调拨、库存调拨和虚仓管理等功能&#xff0c;综合批次管理、物料对应、库存盘点、质检管理、虚仓管理和即时库存管理等功能综合运用的管理系统&#xff0c;有效控制并跟踪仓库业务的物流和成本管理全过程&#xff0c;实现完善的企业仓…

面试笔记——Redis(缓存击穿、缓存雪崩)

缓存击穿 缓存击穿&#xff08;Cache Breakdown&#xff09;&#xff1a; 当某个缓存键的缓存失效时&#xff08;如&#xff0c;过期时间&#xff09;&#xff0c;同时有大量的请求到达&#xff0c;并且这些请求都需要获取相同的数据&#xff0c;这些请求会同时绕过缓存系统&a…

寻找可能认识的人

给一个命名为&#xff1a;friend.txt的文件 其中每一行中给出两个名字&#xff0c;中间用空格分开。&#xff08;下图为文件内容&#xff09; 题目&#xff1a;《查找出可能认识的人 》 代码如下&#xff1a; RelationMapper&#xff1a; package com.fesco.friend;import or…

C 练习实例77-指向指针的指针-二维数组

关于数组的一些操作 #include<stdio.h> #include<stdio.h> void fun(int b[],int length) {for(int i0;i<length;i){printf("%d ",b[i]);}printf("\n");for(int i0;i<length;i){ //数组作为形参传递&#xff0c;传递的是指针&#xff0…

做跨境用哪种代理IP比较好?

代理IP对于做跨境的小伙伴来说&#xff0c;都是必不可少的工具&#xff0c;目前出海的玩法已经是多种多样&#xff0c;开店、账号注册、短视频运营、直播带货、网站SEO等等都是跨境人需要涉及到的业务。而国外代理IP的获取渠道非常多&#xff0c;那么做跨境到底应该用哪种代理I…

onnx 格式模型可视化工具

onnx 格式模型可视化工具 0. 引言1. 可视化工具2. 安装 Netron: Viewer for ONNX models 0. 引言 ONNX 是一种开放格式&#xff0c;用于表示机器学习模型。ONNX 定义了一组通用运算符&#xff08;机器学习和深度学习模型的构建基块&#xff09;和通用文件格式&#xff0c;使 A…

R语言绘图 | 带标签的火火火火火火火山图 | 标记感兴趣基因 | 代码注释 + 结果解读

在火山图中&#xff0c;我们有时候会想要标注出自己感兴趣的基因&#xff0c;这个时候该怎么嘞&#xff01; 还有还有&#xff0c;在添加标签时&#xff0c;可能会遇到元素过多或位置密集导致标签显示不全&#xff0c;或者虽然显示全了但显得密集杂乱&#xff0c;不易阅读的情况…

6.计算机网络

重要章节、考题比重大&#xff01; 主要议题&#xff1a; 1.网络分类 偶尔考 局域网&#xff1a;覆盖面较小&#xff0c;吞吐效率高&#xff0c;传输速度快&#xff0c;可靠性高&#xff1b; 广域网&#xff1a;传输距离较远&#xff0c;通过分组交换技术来实现&#xff1b…

【图论】树链剖分

本篇博客参考&#xff1a; 【洛谷日报#17】树链剖分详解Oi Wiki 树链剖分 文章目录 基本概念代码实现常见应用路径维护&#xff1a;求树上两点路径权值和路径维护&#xff1a;改变两点最短路径上的所有点的权值求最近公共祖先 基本概念 首先&#xff0c;树链剖分是什么呢&…

简单使用NSIS打包软件

NSIS是一个开源的打包工具. 官网: Download - NSIS (sourceforge.io) 使用这个编译 ​ 但是不建议使用这玩意写脚本,字体太难看了.我用vscode写的脚本,用这个编译的. ​ 写好脚本用这个软件打开, 然后选择这个编译,如果语法有错误 会编译不过,会提醒你哪一行不行,如果编译…