基于yolov5和desnet的猫咪识别模型

news2025/4/11 16:03:37

前言

前段时间给学校的猫咪小程序搭建了识猫模型,可以通过猫咪的照片辨别出是那只猫猫,这里分享下具体的方案,先看效果图:

源代码在文末

模型训练

在训练服务器(或你的个人PC)上拉取本仓库代码。

图片数据准备


进入`data`目录,执行`npm install`安装依赖。(需要 Node.js 环境,不确定老版本 Node.js 兼容性,建议使用最新版本。)


复制`config.demo.ts`文件并改名为`config.ts`,填写Laf云环境的`LAF_APPID`;


执行`npm start`,脚本将根据小程序数据库记录拉取小程序云存储中的图片。

如果不打算从laf拉取数据,也可以自己制作数据集,只要保证文件格式如下就可以

catface文件下面的data文件中的photos中有若干个文件夹,每个文件夹名称为id,文件夹下为图片。

环境搭建


返回仓库根目录,执行`python -m pip install -r requirements.txt`安装依赖。(需要Python>=3.8。不建议使用特别新版本的 Python,可能有兼容性问题。)


如果是linux系统,可以直接执行`bash prepare_yolov5.sh`拉取YOLOv5目标检测模型所需的代码,然后下载并预处理模型数据。如果是windows系统可以自己手动从gihub上拉取yolov5的模型。


执行`python3 data_preprocess.py`,脚本将使用YOLOv5从`data/photos`的图片中识别出猫猫并截取到`data/crop_photos`目录。

开始训练

执行`python3 main.py`,使用默认参数训练一个识别猫猫图片的模型。(你可以通过`python3 main.py --help`查看帮助来自定义一些训练参数。)程序运行结束时,你应当看到目录的export文件夹下存在`cat.onnx`和`cat.json`两个文件。(训练数据使用TensorBoard记录在`lightning_logs`文件夹下。若要查看准确率等信息,请自行运行TensorBoard。)


执行`python3 main.py --data data/photos --size 224 --name fallback`,使用修改后的参数训练一个在YOLOv5无法找到猫猫时使用的全图识别模型。程序运行结束时,你应当看到目录的export文件夹下存在`fallback.onnx`和`fallback.json`两个文件。

这里介绍下模型类的代码,我们定义了学习率,网络指定为densenet21

import torch
import torch.nn as nn
from torchvision import models
import torch.optim as optim
from pytorch_lightning import LightningModule
import torchmetrics
from typing import Tuple

class CatFaceModule(LightningModule):
    def __init__(self, num_classes: int, lr: float):
        super(CatFaceModule, self).__init__()

        self.save_hyperparameters()

        self.net = models.densenet121(num_classes=num_classes)
        self.loss_func = nn.CrossEntropyLoss()
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
    
    def training_step(self, batch: Tuple[torch.Tensor, torch.LongTensor], batch_idx: int) -> torch.Tensor:
        loss, acc = self.do_step(batch)

        self.log('train/loss', loss, on_step=True, on_epoch=True)
        self.log('train/acc', acc, on_step=True, on_epoch=True)

        return loss
    
    def validation_step(self, batch, batch_idx: int):
        loss, acc = self.do_step(batch)

        self.log('val/loss', loss, on_step=False, on_epoch=True)
        self.log('val/acc', acc, on_step=False, on_epoch=True)
    
    def do_step(self, batch: Tuple[torch.Tensor, torch.LongTensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        # shape: x (B, C, H, W), y (B), w (B)
        x, y = batch

        # shape: out (B, num_classes)
        out = self.net(x)

        loss = self.loss_func(out, y)

        with torch.no_grad():
            # 每个类别分别计算准确率,以平衡地综合考虑每只猫的准确率
            accuracy_per_class = torchmetrics.functional.accuracy(out, y, task="multiclass", num_classes=self.hparams['num_classes'], average=None)
            # 去掉batch中没有出现的类别,这些位置为nan
            nan_mask = accuracy_per_class.isnan()
            accuracy_per_class = accuracy_per_class.masked_fill(nan_mask, 0)
            # 剩下的位置取均值
            acc = accuracy_per_class.sum() / (~nan_mask).sum()
        
        return loss, acc

    def configure_optimizers(self) -> optim.Optimizer:
        return optim.Adam(self.parameters(), lr=self.hparams['lr'])

在模型训练完毕后可以运行我编写的modelTest,在这个文件中替换图片为自己的图片,观察输出是否正常,正常输出是这样的:

在这个输出中,通过yolo检测了图片中是否含有猫咪,通过densenet对图片所属于的类进行概率计算,概率和id按照概率从大到小排序返回。

接口实现

我们训练了两个densenet模型,一个是全图像的输入为228的模型a,一个是输入图像为128的模型b,当请求打到服务器时,应用程序会先通过yolo检测是否有猫,有的话就截取猫咪图像,使用模型b;否则不截取,使用模型a。

以下是代码:

from typing import Any
from werkzeug.datastructures import FileStorage

import torch
from PIL import Image
import numpy as np
import onnxruntime
from flask import Flask, request
from dotenv import load_dotenv
import os
import json
import time
from base64 import b64encode
from hashlib import sha256

load_dotenv("./env", override=True)

HOST_NAME = os.environ['HOST_NAME']
PORT = int(os.environ['PORT'])

SECRET_KEY = os.environ['SECRET_KEY']
TOLERANT_TIME_ERROR = int(os.environ['TOLERANT_TIME_ERROR']) # 可以容忍的时间戳误差(s)

IMG_SIZE = int(os.environ['IMG_SIZE'])
FALLBACK_IMG_SIZE = int(os.environ['FALLBACK_IMG_SIZE'])

CAT_BOX_MAX_RET_NUM = int(os.environ['CAT_BOX_MAX_RET_NUM']) # 最多可以返回的猫猫框个数
RECOGNIZE_MAX_RET_NUM = int(os.environ['RECOGNIZE_MAX_RET_NUM']) # 最多可以返回的猫猫识别结果个数

print("==> loading models...")
assert os.path.isdir("export"), "*** export directory not found! you should export the training checkpoint to ONNX model."

crop_model = torch.hub.load('yolov5', 'custom', 'yolov5/yolov5m.onnx', source='local')

with open("export/cat.json", "r") as fp:
    cat_ids = json.load(fp)
cat_model = onnxruntime.InferenceSession("export/cat.onnx", providers=["CPUExecutionProvider"])

with open("export/cat.json", "r") as fp:
    fallback_ids = json.load(fp)
fallback_model = onnxruntime.InferenceSession("export/cat.onnx", providers=["CPUExecutionProvider"])

print("==> models are loaded.")

app = Flask(__name__)
# 限制post大小为10MB
app.config['MAX_CONTENT_LENGTH'] = 10 * 1024 * 1024

def wrap_ok_return_value(data: Any) -> str:
    return json.dumps({
        'ok': True,
        'message': 'OK',
        'data': data
    })

def wrap_error_return_value(message: str) -> str:
    return json.dumps({
        'ok': False,
        'message': message,
        'data': None
    })

def check_signature(photo: FileStorage, timestamp: int, signature: str) -> bool:
    if abs(timestamp - time.time()) > TOLERANT_TIME_ERROR:
        return False
    photoBase64 = b64encode(photo.read()).decode()
    photo.seek(0) # 重置读取位置,避免影响后续操作
    signatureData = (photoBase64 + str(timestamp) + SECRET_KEY).encode()
    return signature == sha256(signatureData).hexdigest()

@app.route("/recognizeCatPhoto", methods=["POST"])
@app.route("/recognizeCatPhoto/", methods=["POST"])
def recognize_cat_photo():
    try:
        photo = request.files['photo']
        timestamp = int(request.form['timestamp'])
        signature = request.form['signature']
        if not check_signature(photo, timestamp=timestamp, signature=signature):
            return wrap_error_return_value("fail signature check.")
        
        src_img = Image.open(photo).convert("RGB")
        # 使用 YOLOv5 进行目标检测,结果为[{xmin, ymin, xmax, ymax, confidence, class, name}]格式
        results = crop_model(src_img).pandas().xyxy[0].to_dict('records')
        # 过滤非cat目标
        cat_results = list(filter(lambda target: target['name'] == 'cat', results))
        
        if len(cat_results) >= 1:
            cat_idx = int(request.form['catIdx']) if 'catIdx' in request.form and int(request.form['catIdx']) < len(cat_results) else 0
            
            # 裁剪出(指定的)cat
            cat_result = cat_results[cat_idx]
            crop_box = cat_result['xmin'], cat_result['ymin'], cat_result['xmax'], cat_result['ymax']
            # 裁剪后直接resize到正方形
            src_img = src_img.crop(crop_box).resize((IMG_SIZE, IMG_SIZE))

            # 输入到cat模型
            img = np.array(src_img, dtype=np.float32).transpose((2, 0, 1)) / 255
            scores = cat_model.run([node.name for node in cat_model.get_outputs()], {cat_model.get_inputs()[0].name: img[np.newaxis, :]})[0][0].tolist()

            # 按概率排序
            cat_id_with_score = sorted([dict(catID=cat_ids[i], score=scores[i]) for i in range(len(cat_ids))], key=lambda item: item['score'], reverse=True)
        else:
            # 没有检测到cat
            # 整张图片直接resize到正方形
            src_img = src_img.resize((FALLBACK_IMG_SIZE, FALLBACK_IMG_SIZE))

            img = np.array(src_img, dtype=np.float32).transpose((2, 0, 1)) / 255
            scores = fallback_model.run([node.name for node in fallback_model.get_outputs()], {fallback_model.get_inputs()[0].name: img[np.newaxis, :]})[0][0].tolist()

            # 按概率排序
            cat_id_with_score = sorted([dict(catID=fallback_ids[i], score=scores[i]) for i in range(len(fallback_ids))], key=lambda item: item['score'], reverse=True)

        return wrap_ok_return_value({
            'catBoxes': [{
                'xmin': item['xmin'],
                'ymin': item['ymin'],
                'xmax': item['xmax'],
                'ymax': item['ymax']
            } for item in cat_results][:CAT_BOX_MAX_RET_NUM],
            'recognizeResults': cat_id_with_score[:RECOGNIZE_MAX_RET_NUM]
        })
    except BaseException as err:
        return wrap_error_return_value(str(err))

if __name__ == "__main__":
    app.run(host=HOST_NAME, port=PORT, debug=False)

我们可以在本地运行,如果想测试的小伙伴可以把接口中密钥校验的代码删除,然后直接发送post请求即可。

源码链接

cat-face: 猫脸识别程序,使用yolov5和densenet分类

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1699511.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

moviepy入门

1. 简介 由于恶心的工作和没有规划的部门安排&#xff0c;我被排到了算法部门&#xff0c;从事和算法没有半毛钱关系的业务上&#xff0c;也就是。。。搞视频。咋说呢&#xff1f;视频这东西我没有一点基础&#xff0c;还好有前人写好的代码&#xff0c;用的是moviepy和ffmpeg…

网络攻防概述(基础概念)

文章目录 APTAPT概念APT攻击过程 网络空间与网络空间安全网络空间(Cyberspace)网络空间安全(Cyberspace Security) 网络安全属性机密性(Confidentiality或Security)完整性(Integrity)可用性&#xff08;Availability&#xff09;不可否认性&#xff08;Non-repudiation&#xf…

通过unsplash引入图片素材

如果您还没听说过——当您需要无版权费的照片用于项目时&#xff0c;无论是否用于商业目的&#xff0c;Unsplash 都是您的不二之选。我自己也经常用它来获取大型背景图像。 虽然他们为开发者提供了出色的 API&#xff0c;但他们还为您提供了通过 URL 直接访问随机图片的选项。…

开源博客项目Blog .NET Core源码学习(27:App.Hosting项目结构分析-15)

本文学习并分析App.Hosting项目中后台管理页面的角色管理页面。   角色管理页面用于显示、检索、新建、编辑、删除角色数据同时支持按角色分配菜单权限&#xff0c;以便按角色控制后台管理页面的菜单访问权限。角色管理页面附带一新建及编辑页面&#xff0c;以支撑新建和编辑…

本地连接github仓库

【1】新建github仓库 【2】本地克隆并提交 $ git clone https://github.com/TomJourney/soil.git Cloning into soil... warning: You appear to have cloned an empty repository.pacosonDESKTOP-E4IASRJ MINGW64 /d/github/TomJourney/soil (master) $ git add readme.txtpa…

Android HAL到Framework

一、为什么需要Framwork? Framework实际上是⼀个应⽤程序的框架&#xff0c;提供了很多服务&#xff1a; 1、丰富⽽⼜可扩展的视图&#xff08;Views&#xff09;&#xff0c; 可以⽤来构建应⽤程序&#xff0c;它包括列表&#xff08;lists&#xff09;&#xff0c;⽹格&am…

指针(6)

1. sizeof和strlen的对比 1.1 sizeof 在学习操作符的时候&#xff0c;我们学习了 sizeof &#xff0c; sizeof 计算变量所占内存内存空间大小的&#xff0c;单位是字节&#xff0c;如果操作数是类型的话&#xff0c;计算的是使⽤类型创建的变量所占内存空间的大小。 sizeof 只…

精品丨快速申请免费https证书

https域名证书对提高网站排名有一定的好处&#xff0c;所以当今很多企业为了给网站一个好的安全防护&#xff0c;就会去申请该证书。如今很多企业虽然重视网站的安全防护&#xff0c;但是也重视成本&#xff0c;所以为了节约成本会考虑申请免费的https证书。 第一个好处 企业不…

力扣496. 下一个更大元素 I

Problem: 496. 下一个更大元素 I 文章目录 题目描述思路复杂度Code 题目描述 思路 因为题目说nums1是nums2的子集&#xff0c;那么我们先把nums2中每个元素的下一个更大元素算出来存到一个映射里&#xff0c;然后再让nums1中的元素去查表即可 复杂度 时间复杂度: O ( n 1 n 2…

吉林大学计科21级《软件工程》期末考试真题

文章目录 21级期末考试题一、单选题&#xff08;2分一个&#xff0c;十个题&#xff0c;一共20分&#xff09;二、问答题&#xff08;5分一个&#xff0c;六个题&#xff0c;一共30分&#xff09;三、分析题&#xff08;一个10分&#xff0c;一共2个&#xff0c;共20分&#xf…

正点原子[第二期]Linux之ARM(MX6U)裸机篇学习笔记-22讲 RTC 时钟设置

前言&#xff1a; 本文是根据哔哩哔哩网站上“正点原子[第二期]Linux之ARM&#xff08;MX6U&#xff09;裸机篇”视频的学习笔记&#xff0c;在这里会记录下正点原子 I.MX6ULL 开发板的配套视频教程所作的实验和学习笔记内容。本文大量引用了正点原子教学视频和链接中的内容。…

十四天学会Vue——Vue核心(理论+实战)(第一天)上篇

&#xff01;&#xff01;&#xff01;声明必看&#xff1a;由于本篇开始就写了Vue&#xff0c;内容过多&#xff0c;本篇部分内容还有待完善&#xff0c;小编先去将连续更新的js高阶第四天完成~本篇部分待完善内容明日更新 一、Vue核心&#xff08;上篇&#xff09; 热身top…

mysql - 索引原理

mysql索引原理 文中的查询, 以该表结构为例 CREATE TABLE user (id int NOT NULL COMMENT id,name varchar(255) COLLATE utf8mb4_bin NOT NULL COMMENT 姓名,age int NOT NULL COMMENT 年龄,sex tinyint(1) NOT NULL COMMENT 性别,phone varchar(255) CHARACTER SET utf8mb4…

06中间件RTOS/CP

Autosar CP 操作系统详解-CSDN博客 1. 什么是RTOS &#xff1f; RTOS&#xff0c;英文全称是 Real-time Operation System&#xff0c;中文就是 实时操作系统&#xff0c;又称及时操作系统。 实时操作系统&#xff0c;是指当外界事件或数据产生时&#xff0c;能够接受并以足…

GEC210编译环境搭建

一、下载编译工具链 下载&#xff1a;点击跳转 二、解压到 /usr/local/arm 目录 sudo mv gec210.zip /usr/local/arm cd /usr/local/arm sudo unzip gec210.zip 三、添加到环境变量 PATH/usr/local/arm/arm-cortex_a8-linux-gnueabi-4.7.3/bin:$PATH 四、测试验证 在终端…

微信小程序如何跳转微信公众号

1. 微信小程序如何跳转微信公众号 1.2. 微信公众号配置 登录微信公众号&#xff0c;点击【小程序管理】&#xff1a;   点击【添加】&#xff1a;   点击【关联小程序】&#xff1a;   输入小程序进行关联&#xff1a; 1.2. 微信小程序配置 登录微信小程序&#xf…

力扣刷题---LCS 02. 完成一半题目【简单】

题目描述 有 N 位扣友参加了微软与力扣举办了「以扣会友」线下活动。主办方提供了 2*N 道题目&#xff0c;整型数组 questions 中每个数字对应了每道题目所涉及的知识点类型。 若每位扣友选择不同的一题&#xff0c;请返回被选的 N 道题目至少包含多少种知识点类型。 示例 1&…

【NumPy】关于numpy.sum()函数,看这一篇文章就够了

&#x1f9d1; 博主简介&#xff1a;阿里巴巴嵌入式技术专家&#xff0c;深耕嵌入式人工智能领域&#xff0c;具备多年的嵌入式硬件产品研发管理经验。 &#x1f4d2; 博客介绍&#xff1a;分享嵌入式开发领域的相关知识、经验、思考和感悟&#xff0c;欢迎关注。提供嵌入式方向…

JAVA实现图书管理系统(初阶)

一.抽象出对象: 1.要有书架&#xff0c;图书&#xff0c;用户&#xff08;包括普通用户&#xff0c;管理员用户&#xff09;。根据这些我们可以建立几个包&#xff0c;来把繁杂的代码分开&#xff0c;再通过一个类来把这些&#xff0c;对象整合起来实现系统。说到整合&#xf…

C++ List完全指南:使用方法与自定义实现

文章目录 list的使用几种构造函数 list的实现1.节点类的定义1.1节点类的构造函数 2.正向迭代器实现2.1operator*重载2.2operator->重载2.3operator重载2.4operator--2.5operator和operator&#xff01; 3.反向迭代器实现3.1operator*重载3.2operator->重载3.3operator重载…