第99步 深度学习图像目标检测:SSDlite建模

news2024/11/17 15:30:06

基于WIN10的64位系统演示

一、写在前面

本期,我们继续学习深度学习图像目标检测系列,SSD(Single Shot MultiBox Detector)模型的后续版本,SSDlite模型。

二、SSDlite简介

SSDLite 是 SSD 模型的一个变种,旨在为移动设备和边缘计算设备提供更高效的目标检测。SSDLite 的主要特点是使用了轻量级的骨干网络和特定的卷积操作来减少计算复杂性,从而提高检测速度,同时在大多数情况下仍保持了较高的准确性。

以下是 SSDLite 的主要特性和组件:

(1)轻量级骨干:

SSDLite 不使用 VGG 或 ResNet 这样的重量级骨干。相反,它使用 MobileNet 作为骨干,特别是 MobileNetV2 或 MobileNetV3。这些网络使用深度可分离的卷积和其他轻量级操作来减少计算成本。

(2)深度可分离的卷积:

这是 MobileNet 的核心组件,也被用于 SSDLite。深度可分离的卷积将传统的卷积操作分解为两个较小的操作:一个深度卷积和一个点卷积,这大大减少了计算和参数数量。

(3)多尺度特征映射:

与原始的 SSD 相似,SSDLite 也从不同的层级提取特征图以检测不同大小的物体。

(4)默认框:

SSDLite 也使用默认框(或称为锚框)来进行边界框预测。

(5)单阶段检测:

与 SSD 相同,SSDLite 也是一个单阶段检测器,同时进行边界框回归和分类。

(6)损失函数:

SSDLite 使用与 SSD 相同的组合损失,包括平滑 L1 损失和交叉熵损失。

综上,SSDLite 是为了速度和效率而设计的,特别是针对计算和内存资源有限的设备。通过使用轻量级的骨干和深度可分离的卷积,它能够在减少计算负担的同时,仍然保持合理的检测准确性。

三、数据源

来源于公共数据,文件设置如下:

大概的任务就是:用一个框框标记出MTB的位置。

四、SSDlite实战

直接上代码:

import os
import random
import torch
import torchvision
from torchvision.models.detection import ssdlite320_mobilenet_v3_large
from torchvision.transforms import functional as F
from PIL import Image
from torch.utils.data import DataLoader
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np

# Function to parse XML annotations
def parse_xml(xml_path):
    tree = ET.parse(xml_path)
    root = tree.getroot()

    boxes = []
    for obj in root.findall("object"):
        bndbox = obj.find("bndbox")
        xmin = int(bndbox.find("xmin").text)
        ymin = int(bndbox.find("ymin").text)
        xmax = int(bndbox.find("xmax").text)
        ymax = int(bndbox.find("ymax").text)

        # Check if the bounding box is valid
        if xmin < xmax and ymin < ymax:
            boxes.append((xmin, ymin, xmax, ymax))
        else:
            print(f"Warning: Ignored invalid box in {xml_path} - ({xmin}, {ymin}, {xmax}, {ymax})")

    return boxes

# Function to split data into training and validation sets
def split_data(image_dir, split_ratio=0.8):
    all_images = [f for f in os.listdir(image_dir) if f.endswith(".jpg")]
    random.shuffle(all_images)
    split_idx = int(len(all_images) * split_ratio)
    train_images = all_images[:split_idx]
    val_images = all_images[split_idx:]
    
    return train_images, val_images


# Dataset class for the Tuberculosis dataset
class TuberculosisDataset(torch.utils.data.Dataset):
    def __init__(self, image_dir, annotation_dir, image_list, transform=None):
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.image_list = image_list
        self.transform = transform

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_list[idx])
        image = Image.open(image_path).convert("RGB")
        
        xml_path = os.path.join(self.annotation_dir, self.image_list[idx].replace(".jpg", ".xml"))
        boxes = parse_xml(xml_path)
        
        # Check for empty bounding boxes and return None
        if len(boxes) == 0:
            return None
        
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.ones((len(boxes),), dtype=torch.int64)
        iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)
        
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = torch.tensor([idx])
        target["iscrowd"] = iscrowd
        
        # Apply transformations
        if self.transform:
            image = self.transform(image)
    
        return image, target

# Define the transformations using torchvision
data_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),  # Convert PIL image to tensor
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize the images
])


# Adjusting the DataLoader collate function to handle None values
def collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))
    return tuple(zip(*batch))


def get_ssdlite_model_for_finetuning(num_classes):
    # Load an SSDlite model with a MobileNetV3 Large backbone without pre-trained weights
    model = ssdlite320_mobilenet_v3_large(pretrained=False, num_classes=num_classes)
    return model

# Function to save the model
def save_model(model, path="SSDlite_mtb.pth", save_full_model=False):
    if save_full_model:
        torch.save(model, path)
    else:
        torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

# Function to compute Intersection over Union
def compute_iou(boxA, boxB):
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])
    
    interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
    boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
    
    iou = interArea / float(boxAArea + boxBArea - interArea)
    return iou

# Adjusting the DataLoader collate function to handle None values and entirely empty batches
def collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))
    if len(batch) == 0:
        # Return placeholder batch if entirely empty
        return [torch.zeros(1, 3, 224, 224)], [{}]
    return tuple(zip(*batch))

#Training function with modifications for collecting IoU and loss
def train_model(model, train_loader, optimizer, device, num_epochs=10):
    model.train()
    model.to(device)
    loss_values = []
    iou_values = []
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        total_ious = 0
        num_boxes = 0
        for images, targets in train_loader:
            # Skip batches with placeholder data
            if len(targets) == 1 and not targets[0]:
                continue
            # Skip batches with empty targets
            if any(len(target["boxes"]) == 0 for target in targets):
                continue
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            
            epoch_loss += losses.item()
            
            # Compute IoU for evaluation
            with torch.no_grad():
                model.eval()
                predictions = model(images)
                for i, prediction in enumerate(predictions):
                    pred_boxes = prediction["boxes"].cpu().numpy()
                    true_boxes = targets[i]["boxes"].cpu().numpy()
                    for pred_box in pred_boxes:
                        for true_box in true_boxes:
                            iou = compute_iou(pred_box, true_box)
                            total_ious += iou
                            num_boxes += 1
                model.train()
        
        avg_loss = epoch_loss / len(train_loader)
        avg_iou = total_ious / num_boxes if num_boxes != 0 else 0
        loss_values.append(avg_loss)
        iou_values.append(avg_iou)
        print(f"Epoch {epoch+1}/{num_epochs} Loss: {avg_loss} Avg IoU: {avg_iou}")
    
    # Plotting loss and IoU values
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(loss_values, label="Training Loss")
    plt.title("Training Loss across Epochs")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    
    plt.subplot(1, 2, 2)
    plt.plot(iou_values, label="IoU")
    plt.title("IoU across Epochs")
    plt.xlabel("Epochs")
    plt.ylabel("IoU")
    plt.show()

    # Save model after training
    save_model(model)

# Validation function
def validate_model(model, val_loader, device):
    model.eval()
    model.to(device)
    
    with torch.no_grad():
        for images, targets in val_loader:
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            model(images)

# Paths to your data
image_dir = "tuberculosis-phonecamera"
annotation_dir = "tuberculosis-phonecamera"

# Split data
train_images, val_images = split_data(image_dir)

# Create datasets and dataloaders
train_dataset = TuberculosisDataset(image_dir, annotation_dir, train_images, transform=data_transform)
val_dataset = TuberculosisDataset(image_dir, annotation_dir, val_images, transform=data_transform)

# Updated DataLoader with new collate function
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

# Model and optimizer
model = get_ssdlite_model_for_finetuning(2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train and validate
train_model(model, train_loader, optimizer, device="cuda", num_epochs=10)
validate_model(model, val_loader, device="cuda")

需要从头训练的,就不跑了,摆烂了。

五、写在后面

目标检测模型门槛更高了,运行起来对硬件要求也很高,时间也很久,都是小时起步的。因此只是简单介绍,算是入个门了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1250361.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

逸学java【初级菜鸟篇】10.I/O(输入/输出)

hi&#xff0c;我是逸尘&#xff0c;一起学java吧 目标&#xff08;任务驱动&#xff09; 1.请重点的掌握I/O的。 场景&#xff1a;最近你在企业也想搞一个短视频又想搞一个存储的云盘&#xff0c;你一听回想到自己对于这些存储的基础还不是很清楚&#xff0c;于是回家开始了…

linux shell操作 - 05 IO 模型

文章目录 流IO模型阻塞IO非阻塞IOIO多路复用异步IO网络IO模型 流 可以进行IO&#xff08;input输入、output输出&#xff09;操作的内核对象&#xff1b;如文件、管道、socket…流的入口是fd (file descriptor)&#xff1b; IO模型 阻塞IO&#xff0c; 一直等待&#xff0c;…

基于Vue+SpringBoot的数字化社区网格管理系统

项目编号&#xff1a; S 042 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S042&#xff0c;文末获取源码。} 项目编号&#xff1a;S042&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 源码 & 项目录屏 二、功能模块三、开发背景四、系统展示五…

Swing程序设计(6)边界布局,网格布局

文章目录 前言一、布局介绍 1.边界布局2.网格布局3.网格组布局.总结 前言 Swing程序中还有两种方式边界布局&#xff0c;网格布局供程序员使用。这两种布局方式更能体现出软件日常制作的排列布局格式。 一、布局介绍 1.BorderLayout边界布局 语法&#xff1a;new BorderLayout …

解决几乎任何机器学习问题 -- 学习笔记(组织机器学习项目)

书籍名&#xff1a;Approaching (Almost) Any Machine Learning Problem-解决几乎任何机器学习问题 此专栏记录学习过程&#xff0c;内容包含对这本书的翻译和理解过程 我们首先来看看文件的结构。对于你正在做的任何项目,都要创建一个新文件夹。在本例中,我 将项目命名为 “p…

使用Perplexity AI免费白嫖GPT4的使用次数((智能搜索工具)

一、Perplexity AI是什么 Perplexity AI是一款高质量的智能搜索工具&#xff0c;它可以为用户提供简洁清晰的搜索体验。Perplexity AI内置了基于GPT-4的Copilot搜索功能&#xff0c;用户可以在每四个小时使用五次(白嫖GPT-4)。此外&#xff0c;Perplexity AI有免费和付费&#…

Python是个什么鬼?朋友靠它拿了5个offer

闺蜜乐乐&#xff0c;外院科班出身&#xff0c;手持专八和CATTI证书&#xff0c;没想到找工作时却碰了钉子… 半夜12点&#xff0c;乐乐跟我开启了吐槽模式&#xff1a; 拿到offer的都是小公司的翻译活儿&#xff0c;只能糊个口。稍微好点的平台要求可就多了&#xff0c;不仅语…

以“防方视角”观文件上传功能

为方便您的阅读&#xff0c;可点击下方蓝色字体&#xff0c;进行跳转↓↓↓ 01 案例概述02 攻击路径03 防方思路 01 案例概述 这篇文章来自微信公众号“NearSec”&#xff0c;记录的某师傅接到一个hw项目&#xff0c;在充分授权的情况下&#xff0c;针对客户的系统进行渗透测试…

java计算下一个整10分钟时间点

最近工作上遇到需要固定在整10分钟一个周期调度某个任务&#xff0c;所以需要这样一个功能&#xff0c;记录下 package org.example;import com.google.gson.Gson; import org.apache.commons.lang3.time.DateUtils;import java.io.InputStream; import java.util.Calendar; i…

原型 原型对象 原型链

在面向开发对象开发过程中对每一个实例添加方法&#xff0c;会使每一个对象都存在该添加方法造成空间浪费 通过对原型添加公共的属性或方法&#xff0c;使所有实例对象都可访问 原型为了共享公共的成员 prototype 原型: JS为每个构造函数提供一个属性prototype(原型),它的值…

Redis与Mysql的数据强一致性方案

目的 Redis和Msql来保持数据同步&#xff0c;并且强一致&#xff0c;以此来提高对应接口的响应速度&#xff0c;刚开始考虑是用mybatis的二级缓存&#xff0c;发现坑不少&#xff0c;于是决定自己搞 要关注的问题点 操作数据必须是唯一索引 如果更新数据不是唯一索引&#…

原生小程序图表

原生小程序使用图表 话不多说直接进入正题 官方文档: https://www.ucharts.cn/v2/#/ 下载文件 首先去gitee上把文件下载到自己的项目中 https://gitee.com/uCharts/uCharts 找到微信小程序和里面的组件 把里面src下的文件全部下载下来放入自己项目中 项目文件 新建文件…

hdlbits系列verilog解答(Exams/m2014 q4h)-44

文章目录 一、问题描述二、verilog源码三、仿真结果 一、问题描述 实现以下电路&#xff1a; 二、verilog源码 module top_module (input in,output out);assign out in;endmodule三、仿真结果 转载请注明出处&#xff01;

java集合,ArrayList、LinkedList和Vector,多线程场景下如何使用 ArrayList

文章目录 Java集合1.2 流程图关系1.3 底层实现1.4 集合与数组的区别1.4.1 元素类型1.4.2 元素个数 1.5 集合的好处1.6 List集合我们以ArrayList集合为例1.7 迭代器的常用方法1.8 ArrayList、LinkedList和Vector的区别1.8.1 说出ArrayList,Vector, LinkedList的存储性能和特性1.…

【最新版】SolidWorks 2023 SP5.0 完整版安装包+安装教程

分享模式&#xff1a;免费/绿色&#xff0c;按教程安装 下载地址&#xff1a; https://pan.xunlei.com/s/VNL0-DD_ogcRFwy-xi0HUtlyA1?pwdfzqw# 提取码&#xff1a;fzqw SOLIDWORKS 2023新版本对电脑配置要求 更多详细说明请去官网查看。 安装使用方法&#xff1a; 一、卸…

还不会配置Nginx?刷完这篇就够了

Nginx是一个开源的高性能HTTP和反向代理服务器。它可以用于处理静态资源、负载均衡、反向代理和缓存等任务。Nginx被广泛用于构建高可用性、高性能的Web应用程序和网站。它具有低内存消耗、高并发能力和良好的稳定性&#xff0c;因此在互联网领域非常受欢迎。 为什么使用Nginx…

从 RBAC 到 NGAC ,企业如何实现自动化权限管理?

随着各领域加快向数字化、移动化、互联网化的发展&#xff0c;企业信息环境变得庞大复杂&#xff0c;身份和权限管理面临巨大的挑战。为了满足身份管理法规要求并管理风险&#xff0c;企业必须清点、分析和管理用户的访问权限。如今&#xff0c;越来越多的员工采用移动设备进行…

Python Web开发基础知识篇

一&#xff0c;基础知识篇 本片文章会简单地说一些python开发web中所必须的一些基础知识。主要包括HTML和css基础、JavaScript基础、网络编程基础、MySQL数据库基础、Web框架基础等知识。 1,Web简介 Web&#xff0c;全称为World Wide Web&#xff0c;也就是WWW&#xff0c;万…

入选数据结构与算法领域内容榜第26名

入选数据结构与算法领域内容榜第26名

【广州华锐互动】Web3D云展编辑器能为展览行业带来哪些便利?

在数字时代中&#xff0c;传统的展览方式正在被全新的技术和工具所颠覆。其中&#xff0c;最具有革新意义的就是Web3D云展编辑器。这种编辑器以其强大的功能和灵活的应用&#xff0c;正在为展览设计带来革命性的变化。 广州华锐互动开发的Web3D云展编辑器是一种专门用于创建、编…