Faster RCNN网络源码解读(Ⅱ) --- Faster RCNN源码使用

news2024/10/1 15:27:29

目录

一、源码链接

二、环境配置

三、文件结构 

四、预训练权重下载地址 

五、训练集 

六、训练方法及注意事项 

七、大概看一下训练过程(train_mobilenetv2.py) 


一、源码链接

Faster R-CNN源码链接icon-default.png?t=MBR7https://pan.baidu.com/s/1SQjyLXD47H11ke05OXYSsQ?pwd=gt4f 

二、环境配置

* Python3.6/3.7/3.8
* Pytorch1.7.1(注意:必须是1.6.0或以上,因为使用官方提供的混合精度训练1.6.0后才支持)
* pycocotools(Linux:`pip install pycocotools`; Windows:`pip install pycocotools-windows`(不需要额外安装vs))
* Ubuntu或Centos(不建议Windows)
* 最好使用GPU训练
* 详细环境配置见`requirements.txt`

三、文件结构 

  ├── backbone: 特征提取网络,可以根据自己的要求选择
  ├── network_files: Faster R-CNN网络(包括Fast R-CNN以及RPN等模块)
  ├── train_utils: 训练验证相关模块(包括cocotools)
  ├── my_dataset.py: 自定义dataset用于读取VOC数据集
  ├── train_mobilenet.py: 以MobileNetV2做为backbone进行训练
  ├── train_resnet50_fpn.py: 以resnet50+FPN做为backbone进行训练
  ├── train_multi_GPU.py: 针对使用多GPU的用户使用
  ├── predict.py: 简易的预测脚本,使用训练好的权重进行预测测试
  ├── validation.py: 利用训练好的权重验证/测试数据的COCO指标,并生成record_mAP.txt文件
  └── pascal_voc_classes.json: pascal_voc标签文件

四、预训练权重下载地址 

## 预训练权重下载地址(下载后放入backbone文件夹中):
* MobileNetV2 weights(下载后重命名为`mobilenet_v2.pth`,然后放到`bakcbone`文件夹下): https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
* Resnet50 weights(下载后重命名为`resnet50.pth`,然后放到`bakcbone`文件夹下): https://download.pytorch.org/models/resnet50-0676ba61.pth
* ResNet50+FPN weights: https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
* 注意,下载的预训练权重记得要重命名,比如在train_resnet50_fpn.py中读取的是`fasterrcnn_resnet50_fpn_coco.pth`文件,
  不是`fasterrcnn_resnet50_fpn_coco-258fb6c6.pth`,然后放到当前项目根目录下即可。

五、训练集 

## 数据集,本例程使用的是PASCAL VOC2012数据集
* Pascal VOC2012 train/val数据集下载地址:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
* 如果不了解数据集或者想使用自己的数据集进行训练,请参考我的bilibili:https://b23.tv/F1kSCK
* 使用ResNet50+FPN以及迁移学习在VOC2012数据集上得到的权重: 链接:https://pan.baidu.com/s/1ifilndFRtAV5RDZINSHj5w 提取码:dsz8

六、训练方法及注意事项 

## 训练方法
* 确保提前准备好数据集
* 确保提前下载好对应预训练模型权重
* 若要训练mobilenetv2+fasterrcnn,直接使用train_mobilenet.py训练脚本
* 若要训练resnet50+fpn+fasterrcnn,直接使用train_resnet50_fpn.py训练脚本
* 若要使用多GPU训练,使用`python -m torch.distributed.launch --nproc_per_node=8 --use_env train_multi_GPU.py`指令,`nproc_per_node`参数为使用GPU数量
* 如果想指定使用哪些GPU设备可在指令前加上`CUDA_VISIBLE_DEVICES=0,3`(例如我只要使用设备中的第1块和第4块GPU设备)
* `CUDA_VISIBLE_DEVICES=0,3 python -m torch.distributed.launch --nproc_per_node=2 --use_env train_multi_GPU.py`

## 注意事项
* 在使用训练脚本时,注意要将`--data-path`(VOC_root)设置为自己存放`VOCdevkit`文件夹所在的**根目录**
* 由于带有FPN结构的Faster RCNN很吃显存,如果GPU的显存不够(如果batch_size小于8的话)建议在create_model函数中使用默认的norm_layer,
  即不传递norm_layer变量,默认去使用FrozenBatchNorm2d(即不会去更新参数的bn层),使用中发现效果也很好。
* 训练过程中保存的`results.txt`是每个epoch在验证集上的COCO指标,前12个值是COCO指标,后面两个值是训练平均损失以及学习率
* 在使用预测脚本时,要将`train_weights`设置为你自己生成的权重路径。
* 使用validation文件时,注意确保你的验证集或者测试集中必须包含每个类别的目标,并且使用时只需要修改`--num-classes`、`--data-path`和`--weights-path`即可,其他代码尽量不要改动

七、大概看一下训练过程(train_mobilenetv2.py) 

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using {} device training.".format(device.type))

    # 用来保存coco_info的文件
    results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

    # 检查保存权重文件夹是否存在,不存在则创建
    if not os.path.exists("save_weights"):
        os.makedirs("save_weights")

    data_transform = {
        "train": transforms.Compose([transforms.ToTensor(),
                                     transforms.RandomHorizontalFlip(0.5)]),
        "val": transforms.Compose([transforms.ToTensor()])
    }

    VOC_root = "./"  # VOCdevkit
    aspect_ratio_group_factor = 3
    batch_size = 8
    amp = False  # 是否使用混合精度训练,需要GPU支持

    # check voc root
    if os.path.exists(os.path.join(VOC_root, "VOCdevkit")) is False:
        raise FileNotFoundError("VOCdevkit dose not in path:'{}'.".format(VOC_root))

    # load train data set
    # VOCdevkit -> VOC2012 -> ImageSets -> Main -> train.txt
    train_dataset = VOCDataSet(VOC_root, "2012", data_transform["train"], "train.txt")
    train_sampler = None

    # 是否按图片相似高宽比采样图片组成batch
    # 使用的话能够减小训练时所需GPU显存,默认使用
    if aspect_ratio_group_factor >= 0:
        train_sampler = torch.utils.data.RandomSampler(train_dataset)
        # 统计所有图像高宽比例在bins区间中的位置索引
        group_ids = create_aspect_ratio_groups(train_dataset, k=aspect_ratio_group_factor)
        # 每个batch图片从同一高宽比例区间中取
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, batch_size)

    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using %g dataloader workers' % nw)

    # 注意这里的collate_fn是自定义的,因为读取的数据包括image和targets,不能直接使用默认的方法合成batch
    if train_sampler:
        # 如果按照图片高宽比采样图片,dataloader中需要使用batch_sampler
        train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                        batch_sampler=train_batch_sampler,
                                                        pin_memory=True,
                                                        num_workers=nw,
                                                        collate_fn=train_dataset.collate_fn)
    else:
        train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                        batch_size=batch_size,
                                                        shuffle=True,
                                                        pin_memory=True,
                                                        num_workers=nw,
                                                        collate_fn=train_dataset.collate_fn)

    # load validation data set
    # VOCdevkit -> VOC2012 -> ImageSets -> Main -> val.txt
    val_dataset = VOCDataSet(VOC_root, "2012", data_transform["val"], "val.txt")
    val_data_loader = torch.utils.data.DataLoader(val_dataset,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  pin_memory=True,
                                                  num_workers=nw,
                                                  collate_fn=val_dataset.collate_fn)

    # create model num_classes equal background + 20 classes
    model = create_model(num_classes=21)
    # print(model)

    model.to(device)

    scaler = torch.cuda.amp.GradScaler() if amp else None

    train_loss = []
    learning_rate = []
    val_map = []

    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
    #  first frozen backbone and train 5 epochs                   #
    #  首先冻结前置特征提取网络权重(backbone),训练rpn以及最终预测网络部分 #
    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
    for param in model.backbone.parameters():
        param.requires_grad = False

    # define optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005,
                                momentum=0.9, weight_decay=0.0005)

    init_epochs = 5
    for epoch in range(init_epochs):
        # train for one epoch, printing every 10 iterations
        mean_loss, lr = utils.train_one_epoch(model, optimizer, train_data_loader,
                                              device, epoch, print_freq=50,
                                              warmup=True, scaler=scaler)
        train_loss.append(mean_loss.item())
        learning_rate.append(lr)

        # evaluate on the test dataset
        coco_info = utils.evaluate(model, val_data_loader, device=device)

        # write into txt
        with open(results_file, "a") as f:
            # 写入的数据包括coco指标还有loss和learning rate
            result_info = [f"{i:.4f}" for i in coco_info + [mean_loss.item()]] + [f"{lr:.6f}"]
            txt = "epoch:{} {}".format(epoch, '  '.join(result_info))
            f.write(txt + "\n")

        val_map.append(coco_info[1])  # pascal mAP

    torch.save(model.state_dict(), "./save_weights/pretrain.pth")

    # # # # # # # # # # # # # # # # # # # # # # # # # # # #
    #  second unfrozen backbone and train all network     #
    #  解冻前置特征提取网络权重(backbone),接着训练整个网络权重  #
    # # # # # # # # # # # # # # # # # # # # # # # # # # # #

    # 冻结backbone部分底层权重
    for name, parameter in model.backbone.named_parameters():
        split_name = name.split(".")[0]
        if split_name in ["0", "1", "2", "3"]:
            parameter.requires_grad = False
        else:
            parameter.requires_grad = True

    # define optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005,
                                momentum=0.9, weight_decay=0.0005)
    # learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=3,
                                                   gamma=0.33)
    num_epochs = 20
    for epoch in range(init_epochs, num_epochs+init_epochs, 1):
        # train for one epoch, printing every 50 iterations
        mean_loss, lr = utils.train_one_epoch(model, optimizer, train_data_loader,
                                              device, epoch, print_freq=50,
                                              warmup=True, scaler=scaler)
        train_loss.append(mean_loss.item())
        learning_rate.append(lr)

        # update the learning rate
        lr_scheduler.step()

        # evaluate on the test dataset
        coco_info = utils.evaluate(model, val_data_loader, device=device)

        # write into txt
        with open(results_file, "a") as f:
            # 写入的数据包括coco指标还有loss和learning rate
            result_info = [f"{i:.4f}" for i in coco_info + [mean_loss.item()]] + [f"{lr:.6f}"]
            txt = "epoch:{} {}".format(epoch, '  '.join(result_info))
            f.write(txt + "\n")

        val_map.append(coco_info[1])  # pascal mAP

        # save weights
        # 仅保存最后5个epoch的权重
        if epoch in range(num_epochs+init_epochs)[-5:]:
            save_files = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch}
            torch.save(save_files, "./save_weights/mobile-model-{}.pth".format(epoch))

    # plot loss and lr curve
    if len(train_loss) != 0 and len(learning_rate) != 0:
        from plot_curve import plot_loss_and_lr
        plot_loss_and_lr(train_loss, learning_rate)

    # plot mAP curve
    if len(val_map) != 0:
        from plot_curve import plot_map
        plot_map(val_map)

        先判断是否有可用的GPU,若没有则使用CPU进行训练。

        data_transform定义了图像预处理的函数。

        VOC_root指定了VOC数据集的根目录。

        VOCDataSet定义我们的数据集(通过mydataset.py文件定义我们自己的数据集)

        DataLoader进行数据载入。

        训练过程首先冻结前置特征提取网络权重(backbone),训练rpn以及最终预测网络部分。随后解冻前置特征提取网络权重(backbone),接着训练整个网络权重。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/128192.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于javaweb的学院社团管理系统(idea+servlet+jsp)

一、系统简介 本项目采用idea工具开发,jspservletjquery技术编写,数据库采用的是mysql,navicat开发工具。 系统一共分为3个角色分别是:管理员,学生,社长 获取方式:基于javaweb的学院社团管理系…

VisualStudio2015安装不上的解决方法_选择对应msi_依然报错继续选择---VisualStudio2015工作笔记001

这个visualstudio网上竟然有卖的,真的是太~咱啥也不说了~ 好了说解决办法吧; 1.这里完整版大小挺大的6G多,但是这里我们找的有可能是只有3.8G左右的,这个版本的就是有文件损坏...但是这个版本是可以装上的,也不耽误用的,我亲自测试了. 2.因为很久没写vb.net程序了,想写个小程…

Docker中安装宝塔

1、docker拉取ubuntu系统 docker pull ubuntudocker pull ubuntu 2、运行容器 docker run -i -t -d --name bt -p 2000:20 -p 2100:21 -p 8000:80 -p 4430:443 -p 8880:888 -p 8888:8888 --privilegedtrue -v /d/baota:/www/wwwroot ubuntu -v后的/d/baota代表本地D盘的bao…

【LeetCode】二叉树中的最大路径和 [H](递归)

124. 二叉树中的最大路径和 - 力扣(LeetCode) 一、题目 路径 被定义为一条从树中任意节点出发,沿父节点-子节点连接,达到任意节点的序列。同一个节点在一条路径序列中 至多出现一次 。该路径 至少包含一个 节点,且不一…

擎创技术流 | ClickHouse实用工具—ckman教程(8)

2022年即将迎来终点,各大公司和网站开始出“年终使用报告”了,而关于ClickHouse系列分享也来到了尾期。本系列最初一共规划了10期内容,一转眼就只剩最后2期了,初衷就是单纯的技术分享,如果你能从中得到启发&#xff0c…

圣诞节来临,TikTok上圣诞相关产品销量突破5万单丨超店有数

12月25日,圣诞节正式来临。TikTok卖家和达人如何把握住这个一年一度的「西方春节」,如何实现精准选品,快速爆单。凡事预则立,不预则废。快人一步,掌握趋势,抢占先机。 从谷歌趋势观察,一周内「…

图说Netty服务端启动过程

我们知道Netty是一个基于JDK的nio实现的网络编程框架,那Netty的服务端是怎么启动的呢,包括他是何时register 的,何时 bind 端口的,以及何时开始读取网络中的数据的? 让我们带着这个疑问,通过一个官方的例子…

webRTC 实现人脸识别

首先我们需要先了解一下什么是webRTC 他能做什么 webRTC主要是帮我们处理多媒体应用,如音视频通话,屏幕共享都可以实现,主要基于浏览器API调用,其底层浏览器会调用native C 等一些库帮我们实现的,而我们在应用层掉API …

欧科云链荣获人民网匠心技术奖,科技创新共造企业发展“强引擎”

什么是匠心精神? 是简单的事情重复做 重复的事情用心做 是对自己热爱的事物 不断坚持、不断钻研的精神 昨日,由人民网主办的“2022人民财经高峰论坛”成功举办,论坛还公布了“第十九届人民匠心奖”获奖名单。欧科云链携旗下“链上天眼”产品荣…

第三十七章 数论——博弈论

第三十七章 数论——博弈论一、Nim游戏1、题目2、结论3、结论验证4、代码二、台阶——Nim游戏1、问题2、思路2、代码三、集合——Nim游戏1、问题2、思路—SG()函数2、代码实现(记忆化搜索)一、Nim游戏 1、题目 2、结论 这里直接说结论: 假设…

vue导入私有组件和注册全局组件

目录先下载并配置插件导入私有组件注册全局组件先下载并配置插件 导入的时候需要路径,有个符号,但不能提示路径,需要手打路径,会发现很麻烦,这时候可以通过vscode插件来解决 vscode搜索Path Autocomplete 配置插件,点击插件设置—扩展设置,点开任意一个setting.json中编辑,打开…

Yield Guild Games 和 Axie Infinity:迄今为止的旅程

2022 年,菲律宾被认为是全球应用 web3 的中心,在 Chainalysis 的全球加密货币应用指数中排名第二,在拥有最多 MetaMask 用户的国家中排名第三。虽然该国作为 Coins.ph 和 BloomX 等开创性数字资产交易所以及对加密货币友好的 UnionBank 的所在…

Flutter生命周期

一、组件生命周期 flutter组件只有两种:有状态和无状态组件。由于无状态组件效率高,如果不涉及到组件内部的数据存储,尽量多的使用无状态组件 1、StatelessWidget build:组件渲染 调用次数:1次 StatelessWidget是无状…

MYSQL索引和sql优化

基本 索引是帮助数据高效查询的有序数据结构,没有索引进行查询就会进行全表扫描myisam中的.MYI和innodb中的.idb都是存放索引的文件。索引提高查询效率的同时,也降低了更新表的数据,因为数据库中删改查会维护索引的结构。一般提到的索引就是…

安装mysqlclient失败解决办法

简介 系统:MAC 前因:django使用mysql数据库报错django.core.exceptions.ImproperlyConfigured: Error loading MySQLdb module.Django使用MySQL数据库需要加载 MySQLdb模块,需要安装 mysqlclient(django2.2以前安装pymysql&#…

ipsec 建立正常后业务端口测试通过但是业务访问故障

某公司ipsec 分支和总部对接成功后,检查发现两端业务测试正常。分支和总部服务器可以互访,分支测试总部服务器80端口开放正常,排除两侧因策略的影响。但是总部服务器web业务和数据库业务均出现无法访问,web页面打不开,…

Es初步检索命令

1、_cat GET /_cat/nodes:查看所有节点 请求 : http://192.168.107.129:9200/_cat/nodes 响应 : 127.0.0.1 15 95 8 0.19 0.16 0.24 dilm * 32bb46713f1b GET /_cat/health:查看 es 健康状况 请求 : http://192.16…

kali - 通过抓包获取FTP连接密码

数据来源 开始实验 实验目标:在win2003主机部署一台FTP服务器,然后winXP通过账户密码连接2003主机的FTP服务器,kali虚拟机抓取数据包获取密码。 实验拓补图 实验流程: 1、开启虚拟机并配置IP 我这里开了一台Kali、一台window…

elmentUI组建中el-date-picker实现限制时间范围精确到小时

elmentUI组建中el-date-picker实现限制时间范围精确到小时 需求要求 时间选择器只能选择今天之前的日期.默认时间是前一天00点~23点后台返回的最小时间和最大时间时间精度限制到小时 开始想着用type"datetimerange"来实现,后来发现控制时间禁用无法实现,后改变思路…

SkyEye助力飞控软件Debug

​01.Debug是什么? 1947年9月9日,美国著名科学家格蕾丝.霍普(Grace Hopper)与其同伴在对Mark II计算机进行研究时发现,导致计算机无法正常工作的罪魁祸首居然是一只粘在继电器上的小飞蛾。格蕾丝用镊子将飞蛾夹出&…