FasterRCNN训练自己的数据集

news2024/10/6 18:36:15

2016年提出的Faster RCNN目标检测模型是深度学习现代目标检测算法的开山之作,也是第一个真正全流程都是神经网络的目标检测模型。

其主要步骤如下:

1,使用CNN对输入图片提取feature map.

2,对feature map上的每个点设计一套不同大小和长宽比的anchor作为先验框。

3,设计RPN网络从大量的anchor中筛选出一些作为目标框的proposals并用回归分支纠正它们的位置。

4,使用ROI Pooling技术对不同大小的proposals获取相同大小的对应特征图,以便后续分类模型一并处理。

5,在proposals的feature map上使用分类分支和回归分支进一步预测目标类别和更精确的定位。

anchor技巧ROI Pooling技术 是非常值得学习的技巧,在许多目标检测模型中都能看到他们的身影。

b1c2e63ddc895a32120c5b0ddb03a614.jpeg

尽管FasterRCNN历史悠久,但依然是一个非常重要的目标检测任务的baseline.

一般会把它叫做two-stage的目标检测模型,主要是如果train from scratch,   RPN网络提取proposals和后续对propasals的定位分类 这两个步骤是要分开训练的,但在微调的时候,通常可以一起训练。

本文我们主要演示调用torchvision中的faster-rcnn模型在自己的数据集上微调来检测螺丝螺母。

#!pip install torchvision,torchkeras
import numpy as np
import pandas as pd 
from matplotlib import pyplot as plt
from PIL import Image,ImageColor,ImageDraw,ImageFont 

import torch
from torch import nn
import torchvision
from torchvision import datasets, models, transforms

import datetime
import os
import copy
import json 

print(torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
2.0.0+cu117

〇,预训练模型

from torchkeras.data import get_example_image
img = get_example_image('park.jpg')
img.save('park.jpg')
from torchkeras.plots import vis_detection 

# 准备数据
inputs = []
img = Image.open('park.jpg').convert("RGB")
img_tensor = torch.from_numpy(np.array(img)/255.).permute(2,0,1).float()
if torch.cuda.is_available():
    img_tensor = img_tensor.cuda()
inputs.append(img_tensor)    

# 加载模型
num_classes = 91
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1,
    num_classes = num_classes)

if torch.cuda.is_available():
    model.to("cuda:0")
model.eval()

# 预测结果
with torch.no_grad():
    predictions = model(inputs)


# 结果可视化
class_names = torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1.meta['categories']

vis_detection(img,predictions[0],class_names,min_score = 0.8)

ed52a62468ca74dbcbd773733d61e7c5.png

下面代码我们演示使用我开发的优雅的torchkeras工具在自己的数据集上对Faster-RCNN模型进行finetune。

我们使用一个非常简单的螺丝(bolt)螺母(nut)数据集作为示范。

公众号 算法美食屋 后台回复关键词:torchkeras,获取本文notebook代码和 bolt nut 数据集 下载地址。

一,准备数据

data_path = "./data/bolt_nut"

train_images_path = "./data/bolt_nut/train"
train_targets_path = './data/bolt_nut/train.txt'

val_images_path = "./data/bolt_nut/val"
val_targets_path = './data/bolt_nut/val.txt'

class_names = ['__background__','bolt','nut']
class BoltNut(torch.utils.data.Dataset):
    def __init__(self, images_path, targets_path, 
                 class_names = class_names,
                 transforms = None
                ):
        self.images_path = images_path
        self.targets_path = targets_path
        self.transforms = transforms
        self.infos_list = open(targets_path,"r").readlines()
        self.class_names = class_names

    def __getitem__(self, idx):
        
        info_str = self.infos_list[idx]
        info_arr = info_str.replace("\n","").replace("\t ","").split("\t")
        
        img_path = info_arr.pop(0)
        
        info_arr = [x for x in info_arr if x.strip()] 
        infos = [json.loads(x) for x in info_arr]

        img= Image.open(os.path.join(self.images_path,img_path)).convert("RGB")

        target = {}
        target["image_id"] = torch.tensor([int(img_path.split(".")[0])],dtype = torch.int64)  
        target["labels"] = torch.tensor([self.class_names.index(x["value"]) for x in infos],
                                        dtype = torch.int64)

        coords = [x["coordinate"]  for x in infos]
        boxes = torch.tensor([[xmin,ymin,xmax,ymax] for (xmin,ymin), (xmax,ymax)  in coords])
        target["boxes"] = boxes

        target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        target["iscrowd"] = torch.zeros((len(infos) ,), dtype=torch.int64)
        
        if self.transforms is not None:
            img, target = self.transforms(img, target)
  
        return img, target

    def __len__(self):
        return len(self.infos_list)
# 可视化数据集
ds_train = BoltNut(train_images_path,train_targets_path)
img,target = ds_train[12]

target["scores"] = torch.ones_like(target["labels"])
img_result = vis_detection(img,target,class_names,min_score = 0.8)
img_result

0f773c51efefd49c9ef239c74d8c1302.png

下面我们设计数据增强模块

import random 
from torchvision import transforms as T

class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


class RandomHorizontalFlip(object):
    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-1)
            bbox = target["boxes"]
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
            target["boxes"] = bbox
            if "masks" in target:
                target["masks"] = target["masks"].flip(-1)
        return image, target


class ToTensor(object):
    def __call__(self, image, target):
        image = T.ToTensor()(image)
        return image, target
transforms_train = Compose([ToTensor(),RandomHorizontalFlip(0.5)])
transforms_val = ToTensor()

ds_train = BoltNut(train_images_path,train_targets_path,transforms=transforms_train)
ds_val = BoltNut(val_images_path,val_targets_path,transforms=transforms_val)
def collate_fn(batch):
      return tuple(zip(*batch))

dl_train = torch.utils.data.DataLoader(ds_train, batch_size=2, 
          shuffle=True, num_workers=4,collate_fn= collate_fn)

dl_val = torch.utils.data.DataLoader(ds_val, batch_size=2, 
          shuffle=True, num_workers=4,collate_fn= collate_fn)
for batch in dl_train:
    features,labels = batch  
    break

二,定义模型

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

num_classes = 3  # 3 classes (bult,nut) + background
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

三,训练模型

from torchkeras import KerasModel
class StepRunner:
    def __init__(self, net, loss_fn, accelerator, stage = "train", metrics_dict = None, 
                 optimizer = None, lr_scheduler = None
                 ):
        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage
        self.optimizer,self.lr_scheduler = optimizer,lr_scheduler
        self.accelerator = accelerator
        if self.stage=='train':
            self.net.train() 
        else:
            self.net.train() #attention here
    
    def __call__(self, batch):
        features,labels = batch 
        
        #loss
        loss_dict = self.net(features,labels)
        loss = sum(loss_dict.values())
        
        #backward()
        if self.optimizer is not None and self.stage=="train":
            self.accelerator.backward(loss)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()
            
        #all_preds = self.accelerator.gather(preds)
        #all_labels = self.accelerator.gather(labels)
        all_loss = self.accelerator.gather(loss).sum()
        
        #losses
        step_losses = {self.stage+"_loss":all_loss.item()}
        
        #metrics
        step_metrics = {}
        
        if self.stage=="train":
            if self.optimizer is not None:
                step_metrics['lr'] = self.optimizer.state_dict()['param_groups'][0]['lr']
            else:
                step_metrics['lr'] = 0.0
        return step_losses,step_metrics
    
KerasModel.StepRunner = StepRunner
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
                             momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=4)

keras_model = KerasModel(model,
                         loss_fn = None,
                         metrics_dict=None,
                         optimizer= optimizer,
                         lr_scheduler=lr_scheduler
                        )

keras_model.fit(train_data=dl_train,val_data=dl_val,
    epochs=20,patience=5,
    monitor='val_loss',
    mode='min',
    ckpt_path ='faster-rcnn.pt',
    plot=True
)

d299f12f273b84c75b01062c0740f2b0.png

024eaa1bd0ac0807dfeea0bdc3963d7c.png

四,评估模型

import torch 

from PIL import Image 
from tqdm import tqdm
from ultralytics.yolo.utils import set_logging
set_logging(verbose=False)
from ultralytics.yolo.utils.metrics import  DetMetrics, box_iou
def process_batch(predictions, targets, 
                  iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
                 ):
    ...
    return metrics
model.eval()
list_predictions = [model(x[0].to('cuda')[None,...])[0] for x in ds_val]
list_targets = [x[1] for x in ds_val]

names = {0:'bolt',1:'nut'}
metrics = eval_metrics(list_predictions,
                       list_targets,
                       names =  names)
display(metrics.results_dict)
{'metrics/precision(B)': 0.9976781395819151,
 'metrics/recall(B)': 1.0,
 'metrics/mAP50(B)': 0.995,
 'metrics/mAP50-95(B)': 0.8542317510036526,
 'fitness': 0.8683085759032874}
import pandas as pd 
df = pd.DataFrame()
df['metric'] = metrics.keys
for i,c in names.items():
    df[c] = metrics.class_result(i)
df

9dfe2e3a0af99c86ab178b6e85870002.png

五,使用模型

# 准备数据
inputs = []
img_path = os.path.join(val_images_path,os.listdir(val_images_path)[5])
img = Image.open(img_path).convert("RGB")
img_tensor = torch.from_numpy(np.array(img)/255.).permute(2,0,1).float()
if torch.cuda.is_available():
    img_tensor = img_tensor.cuda()
inputs.append(img_tensor)    

model.eval()

# 预测结果
with torch.no_grad():
    predictions = model(inputs)

# 结果可视化
vis_detection(img,predictions[0],list(idx2names.values()),min_score = 0.8)

38fed89e054034ac169df0251abc66d5.png

公众号 算法美食屋 后台回复关键词:torchkeras,获取本文notebook代码和 bolt nut 数据集 下载地址。

万水千山总是情,点个赞赞行不行?😋😋

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/553320.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Roboflow的使用

文章目录 前言一、使用labelimg标注数据集二、导入roboflow1.注册roboflow账户2.导入图片2.1 创建工作区workspace&#xff08;非必须&#xff09;2.2 创建项目 project2.3 导入 3、导出图片4、同一个数据集可以导出不同类型 前言 我自己也是一个小白不是很会&#xff0c;如果…

ASO优化之怎么做好关键词本地化覆盖

如果想要我们的应用走向国际化&#xff0c;被多个国家/地区使用&#xff0c;那么做好关键词本地化覆盖至关重要。我们可以主要针对中文和英文进行设置&#xff08;准备两套元数据&#xff09;&#xff0c;这样能够迅速增加应用商店ASO关键词覆盖数量。 那么我们要在哪里设置&a…

小白也能懂的薛斯通道抄底指标以及公式(附源码)

什么是薛斯通道&#xff1f; 上个世纪70年代&#xff0c;美国人薛斯最早发明了薛斯通道。 他本人曾是研究火箭运行的。 薛斯通道包括两组通道指标&#xff0c;分别是长期大通道指标&#xff08;100天&#xff09;和短期小通道指标&#xff08;10天&#xff09;。 股价实际上是被…

Netflix 团队解决了 Linux 内核中的 FUSE 死锁

Laf 公众号已接入了 AI 绘画工具 Midjourney&#xff0c;可以让你轻松画出很多“大师”级的作品。同时还接入了 AI 聊天机器人&#xff0c;支持 GPT、Claude 以及 Laf 专有模型&#xff0c;可通过指令来随意切换模型。欢迎前来调戏&#x1f447; <<< 左右滑动见更多 &…

Go与神经网络:张量运算

0. 背景 2023年年初&#xff0c;我们很可能是见证了一次新工业革命的起点&#xff0c;也可能是见证了AGI(Artificial general intelligence&#xff0c;通用人工智能)[1]孕育的开始。ChatGPT应用以及后续GPT-4大模型的出现&#xff0c;其震撼程度远超当年AlphaGo战胜人类顶尖围…

微信小程序-页面跳转wxAPI

官方文档地址&#xff1a;https://developers.weixin.qq.com/miniprogram/dev/api/route/wx.navigateTo.html wx.navigateTo(Object object) 更改首页代码&#xff0c;添加一个按钮&#xff0c;绑定一个事件的点击&#xff1a; <!--index.wxml--> <text>首页</t…

《前端》HTML常用标签

文章目录 HTML导读HTML格式常用标签标题标签段落标签格式化标签超链接标签标签的几种形式 表格标签列表标签表单标签按钮标签无语义标签 ​&#x1f451;作者主页&#xff1a;Java冰激凌 &#x1f4d6;专栏链接&#xff1a;前端 HTML导读 html是超文本标记语言 一般直接运行在…

33从零开始学Java之方法的递归调用到底是怎么回事?

作者&#xff1a;孙玉昌&#xff0c;昵称【一一哥】&#xff0c;另外【壹壹哥】也是我哦 千锋教育高级教研员、CSDN博客专家、万粉博主、阿里云专家博主、掘金优质作者 前言 在之前的文章中&#xff0c;壹哥给大家讲解了方法的定义、调用及参数、返回值等内容&#xff0c;接下…

广告行业中那些趣事系列62:keybert在实际业务中的使用分享

导读&#xff1a;本文是“数据拾光者”专栏的第六十二篇文章&#xff0c;这个系列将介绍在广告行业中自然语言处理和推荐系统实践。本篇作为之前keybert的补充主要介绍了keybert在实际业务中的使用分享&#xff0c;对于希望在实际业务场景中使用keybert的小伙伴可能有帮助。 欢…

微信小程序-页面生命周期方法

在经过上一篇文章的介绍之后&#xff0c;我们知道了大体的生命周期在什么时候执行&#xff0c;这次主要是以代码的形式来展示一下具体的阶段执行什么生命周期方法。 首先我们编写一个代码可以从首页跳转到日志页面&#xff1a; <!--index.wxml--> <text>首页</t…

项目中excel表格中由合同内容--转换为验收清单的办法(python操作excel表格)

需求&#xff1a; 把合同内容--转换为验收清单的办法&#xff08;python操作excel表格&#xff09; 1.字段重新排序 2.选择需要的表格列 原始的表格内容&#xff1a; 需要的格式&#xff1a; 涉及的技术点&#xff1a; 1.读取原始表格“readexcel1.xlsx”内容&#xff0c;修改…

第十一章 Productions最佳实践 - 生产电子表格

文章目录 第十一章 Productions最佳实践 - 生产电子表格生产电子表格界面设计 第十一章 Productions最佳实践 - 生产电子表格 生产电子表格 维护一个电子表格是很有帮助的&#xff0c;它可以逐个应用程序地组织信息系统。作为一般准则&#xff0c;应该为每个提供传入或传出数…

# 性能诊断 JProfiler 工具使用

性能诊断 JProfiler 工具使用 JProfiler是一个重量级的JVM监控工具&#xff0c;提供对JVM精确监控&#xff0c;其中堆遍历、CPU剖析、线程剖析看成定位当前系统瓶颈的得力工具。可以统计压测过程中JVM的监控数据&#xff0c;定位性能问题。 官网地址&#xff1a;Java Profiler…

初识linux之网络基础概念

目录 一、网络发展 1. 独立模式 2. 网络互联 二、认识协议 1. 为什么要有协议 2. 什么是协议 三、网络协议初识 1. 协议分层 2. 协议分层的优点 3. 理解分层 4. OSI七层模型 4.1 概念 4.2 模型形式 4.3 各层的作用 5. TCP/IP五层&#xff08;或四层&#xff09…

书评 | 《深入理解高并发编程:JDK核心技术》

书评 | 《深入理解高并发编程&#xff1a;JDK核心技术》 作者简介 冰河&#xff1a;互联网资深技术专家、数据库技术专家、分布式与微服务架构专家&#xff1b;多年来一直致力于分布式系统架构、微服务、分布式数据库、分布式事务与大数据技术的研究&#xff0c;在高并发、高可…

MySQL高级篇——关联查询和子查询优化

导航&#xff1a; 【黑马Java笔记踩坑汇总】Java基础进阶JavaWebSSMSpringBoot瑞吉外卖SpringCloud黑马旅游谷粒商城学成在线设计模式牛客面试题 目录 1. 关联查询优化 1.0 优化方案 1.1 数据准备 1.2 左外连接&#xff1a;优先右表创建索引&#xff0c;连接字段类型要一致…

numpy-stl实战3D建模【Python】

想象一下&#xff0c;我们需要用 python 编程语言构建某个物体的三维模型&#xff0c;然后将其可视化&#xff0c;或者准备一个文件以便在 3D 打印机上打印。 有几个库可以解决这些问题。 让我们来看看&#xff0c;如何在 Python 中从点、边和图元构建 3D 模型。 如何执行基本的…

如何对图片进行卷积计算

1 问题 如何对图片进行卷积计算&#xff1f; 2 方法 先导入torch和torch里的nn类&#xff0c;然后设置一个指定尺寸的随机像素值的图片&#xff0c;然后使用nn.conv2d函数进行卷积计算&#xff0c;然后建立全连接层&#xff0c;最后得到新的图片的尺寸 步骤: (1) 导入实验所需要…

CyberLink的音频编辑软件AudioDirector Ultra 13.4版本在win10系统的下载与安装配置教程

目录 前言一、AudioDirector Ultra安装二、使用配置总结 前言 AudioDirector Ultra是由CyberLink公司开发的一款强大的音频编辑工具&#xff0c;旨在为用户提供全面的音频后期制作和编辑解决方案。该软件支持多种音频格式&#xff0c;包括MP3、WAV、M4A等&#xff0c;并且可以…

网络工程师精选习题详解(二)

请点击↑关注、收藏&#xff0c;本博客免费为你获取精彩知识分享&#xff01;有惊喜哟&#xff01;&#xff01; 201.通常使用&#xff08;&#xff09;为IP数据报进行加密。 A.IPSec B.PP2P C.HTTPS D.TLS 答案&#xff1a;A IP Sec可以为IP数据报进行加密。 …