基于PaddleClas的人物年龄分类项目

news2024/9/27 12:25:33

目录

一、任务概述

二、算法研发

2.1 下载数据集

2.2 数据集预处理

2.3 安装PaddleClas套件

2.4 算法训练

2.5 静态图导出

2.6 静态图推理

三、小结


一、任务概述

    最近遇到个需求,需要将图像中的人物区分为成人和小孩,这是一个典型的二分类问题,打算采用飞桨的图像分类套件PaddleClas来完成算法研发。本文记录相关流程。

二、算法研发

2.1 下载数据集

    本文采用MaGaAge_Asian数据集,该数据集主要由亚洲人图片组成,训练集包含40000张图像,验证集包含3495张图像,每张图像都有对应的年龄真值,所有图像均处理成了统一的大小,宽178像素,高218像素。

数据集地址下载链接。数据集部分示例如下图所示:

    该数据集本意是用来做年龄预测的,属于一个数值回归任务,本文将其变成二分类任务,以13岁年龄为界限,小于该年龄的属于小孩,大于该年龄的属于成人。这里之所以选择13岁,因为这个任务是需要筛选出长得很“像”小孩的小孩,13岁以上的青少年很多本身已经长的像成人了,因此,选择13岁作为分界线。

    下面首先对该数据集进行处理。

2.2 数据集预处理

    MaGaAge_Asian数据集每张图片对应的人物年龄存放在list文件夹的两个文件中,其中train_age.txt存放训练集对应的年龄真值,test_age.txt存放验证集对应的年龄真值。下面要写一个脚本,将所有小于13岁的图片移动到一个文件夹内,所有大于等于13岁的图片移动到另一个文件夹内。

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@文件        :split_asian.py
@说明        :拆分megaage_asian数据集,将小于13岁的移动到一个文件夹,大于等于13岁的移动到另一个文件夹
@时间        :2024/07/16 09:11:16
@作者        :Bin Qian
@版本        :1.0
'''


import os
import cv2

thr = 13 # 年龄阈值

# 读取年龄列表
agefile = 'megaage_asian/list/test_age.txt'
f=open(agefile) 
ageLst = f.read().splitlines()
f.close() 

# 读取图像
imgFolder = 'megaage_asian/val'
imgnames = os.listdir(imgFolder)
index = 50000
for imgname in imgnames:
    imgPath = os.path.join(imgFolder,imgname)
    img = cv2.imread(imgPath)
    if img is None:
        continue
    print(imgPath)
    imgindex = int(imgname.split('.')[0])
    age = int(ageLst[imgindex-1])
    if age < thr:
        dstFolder = 'ageclas/child'
    else:
        dstFolder = 'ageclas/adult'
    
    savePath = os.path.join(dstFolder,str(index)+'_asian.jpg')
    cv2.imwrite(savePath,img)
    index += 1
print('完成')

值得注意的是MaGaAge_Asian数据集中有很多质量较差的图像,这些“脏”图像会影响学习效果,最好手工检查这些数据并将其剔除。

另外,为了能够取得更好的效果,本文从互联网和FFHQ数据集里面再挑选出一些小孩和成人的照片进行补充。部分代码如下:

import os
import cv2

# 读取图像
imgFolder = 'adult'
imgnames = os.listdir(imgFolder)
index = 1
for imgname in imgnames:
    imgPath = os.path.join(imgFolder,imgname)
    img = cv2.imread(imgPath)
    if img is None:
        continue
    print(imgPath)
    dstFolder = 'ageclas/adult'
    savePath = os.path.join(dstFolder,str(index)+'_data.jpg')
    cv2.imwrite(savePath,img)
    index += 1
print('完成')

补充完整后,最后对整理好的数据集进行拆分,并且获得对应的文件列表:

# 导入系统库
import os
import random
import cv2


# 定义参数
img_folder = 'ageclas'
trainlst = 'train_list.txt'
vallst = 'val_list.txt'
ratio = 0.95 # 训练集占比
labellst='label.txt'
 

def writeLst(lstpath,namelst):
    '''
    保存文件列表
    '''
    print('正在写入 '+lstpath)
    random.shuffle (namelst)
    # 写入训练样本文件
    f=open(lstpath, 'a', encoding='utf-8')
    for i in range(len(namelst)):
        text = namelst[i]+'\n'
        f.write(text)
    f.close()
    print(lstpath+ '已完成写入')
    

 
def main():
    '''
    主函数
    '''
    # 查找文件夹
    folderlst = os.listdir(img_folder)
    print('共找到 %d 个文件夹' % len(folderlst))
    
    # 循环处理
    trainnamelst = list()
    valnamelst = list()
    labelnamelst = list()
    for i in range(len(folderlst)):
        class_name = folderlst[i]
        class_label = i
        print('开始处理 '+class_name+' 文件夹')
        
        # 获取子文件夹文件列表
        filenamelst = os.listdir(os.path.join(img_folder,class_name))
        totalNum = len(filenamelst)
        print('当前文件夹图片数量为: ' + str(totalNum)) 
        trainNum = int(ratio*totalNum)
        text =  str(class_label)+ ' ' + class_name
        labelnamelst.append(text)
        
        # 检查并校验图像
        for j in range(totalNum):
            imgpath = os.path.join(img_folder,class_name,filenamelst[j])
            img = cv2.imread(imgpath, cv2.IMREAD_COLOR)
            if img is None:
                continue
            text = imgpath + ' ' + str(class_label)
            if j <= trainNum: 
                trainnamelst.append(text)
            else:
                valnamelst.append(text)
                
    writeLst(trainlst,trainnamelst)
    writeLst(vallst,valnamelst)   
    writeLst(labellst,labelnamelst)     
    print('全部完成')


if __name__ == '__main__':
    '''程序入口'''
    main()

运行后会生成train_lst.txt、val_lst.txt以及label.txt三个文件,有了这三个文件就可以使用PaddleClas套件进行算法研发了。

2.3 安装PaddleClas套件

git clone https://gitee.com/paddlepaddle/PaddleClas.git
cd PaddleClas
sudo python setup.py install

2.4 算法训练

在PaddleClas目录下新建一个配置文件config_lcnet.yaml,采用PPLCNet_x0_5模型来训练,配置文件代码如下:

# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 5
  eval_during_train: True
  eval_interval: 5
  epochs: 200
  print_batch_step: 10
  use_visualdl: True
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./output/inference
# model architecture
Arch:
  name: PPLCNet_x0_5
  class_num: 2
 
# loss function config for traing/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
        epsilon: 0.1
  Eval:
    - CELoss:
        weight: 1.0


Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Cosine
    learning_rate: 0.8
    warmup_epoch: 5
  regularizer:
    name: 'L2'
    coeff: 0.00003


# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ../process_data/
      cls_label_path: ../process_data/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            size: [224,224]
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''

    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    dataset: 
      name: ImageNetDataset
      image_root: ../process_data/
      cls_label_path: ../process_data/val_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            size: [224,224]
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: "../testimgs/10.jpg"
  batch_size: 1
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        size: [224,224]
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: Topk
    topk: 1
    class_id_map_file: "../process_data/label.txt"

Metric:
  Train:
    - TopkAcc:
        topk: [1]
  Eval:
    - TopkAcc:
        topk: [1]

然后使用下面的命令进行训练:

export CUDA_VISIBLE_DEVICES=0,1
python3 -m paddle.distributed.launch \
    --gpus="0,1" \
    tools/train.py \
        -c config_lcnet.yaml 

训练完成后可以使用下面的命令可视化查看训练结果:

visualdl --logdir results/vdl

运行效果如下:

可以看到,基本在epoch=100以后就收敛了,最高top1准确率达到97.5%,准确率还是比较高的。

下面可以使用动态图对单张图像进行测试,命令如下:

python3 tools/infer.py -c config_lcnet.yaml -o Global.pretrained_model=output/PPLCNet_x0_5/best_model

输出如下:

[{'class_ids': [1], 'scores': [0.93522], 'file_name': '../testimgs/10.jpg', 'label_names': ['adult']}]

2.5 静态图导出

为了方便后面进行模型部署,将训练好的最佳模型进行静态图导出。具体命令如下:

python3 tools/export_model.py \
    -c config_lcnet.yaml \
    -o Global.pretrained_model=output/PPLCNet_x0_5/best_model \
    -o Global.save_inference_dir=output/inference

导出的静态图模型存放在output/inference文件夹下面,整个模型参数加起来不超过3M,因此可以看出这个训练好的PPLCNet_x0_5模型是一个非常轻量级的模型。

2.6 静态图推理

下面使用静态图来进行推理。在推理前先使用visualdl工具查看下静态图模型的输入和输出,这将为编写推理脚本奠定基础。

可以看到,输入是[batch,3,224,224]的float型图像数据,输出是[batch,2]的float型数据。尤其是输出的两个值,代表的是两个类别的概率。

有了上面的分析,下面可以用PaddleInference写一个推理脚本infer.py:

import cv2
import numpy as np
from paddle.inference import create_predictor
from paddle.inference import Config as PredictConfig

# 加载静态图模型
model_path = "./output/inference/inference.pdmodel"
params_path = "./output/inference/inference.pdiparams"
pred_cfg = PredictConfig(model_path, params_path)
pred_cfg.enable_memory_optim()  # 启用内存优化
pred_cfg.switch_ir_optim(True)
pred_cfg.enable_use_gpu(500, 0)  # 启用GPU推理
predictor = create_predictor(pred_cfg)  # 创建PaddleInference推理器

# 解析模型输入输出
input_names = predictor.get_input_names()
input_handle = {}
for i in range(len(input_names)):
    input_handle[input_names[i]] = predictor.get_input_handle(input_names[i])
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])

# 图像预处理
img = cv2.imread("../testimgs/10.jpg", flags=cv2.IMREAD_COLOR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_AREA)
img = img.astype(np.float32)
PIXEL_MEANS =(0.485, 0.456, 0.406)    # RGB格式的均值和方差
PIXEL_STDS = (0.229, 0.224, 0.225)
img/=255.0
img-=np.array(PIXEL_MEANS)
img/=np.array(PIXEL_STDS)
img = np.transpose(img[np.newaxis, :, :, :], (0, 3, 1, 2))

# 预测
input_handle["x"].copy_from_cpu(img)
predictor.run()
results = output_handle.copy_to_cpu()

# 后处理
results = results.squeeze(0)
if results[0]>results[1]:
    print('小孩'+"  "+str(results[0]))
else:
    print('大人'+"  "+str(results[1]))

从网上随便找两张照片,运行效果如下:

输出结果:

小孩  0.7256172

输出结果:

大人  0.9533998

可以看到,推理效果还是比较满意的。

三、小结

本文以项目为主线,使用了PaddleClas算法套件解决了年龄分类问题。后续读者如果想要深入学习PaddlePaddle(飞桨)及相关算法套件,可以关注我的书籍(链接)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1949370.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

设计模式笔记(一)

目录 设计模式共有23种&#xff0c;也可称为GOF23 单例模式&#xff08;重点&#xff0c;常用&#xff09; 工厂模式 代理模式&#xff1a;&#xff08;SpringAOP的底层原理&#xff09; 静态代理模式&#xff1a;&#xff08;写死一个代理类Proxy&#xff09; 动态代理模…

【Java版数据结构】初识泛型

看到这句话的时候证明&#xff1a;此刻你我都在努力 加油陌生人 br />个人主页&#xff1a;Gu Gu Study专栏&#xff1a;Java版数据结构 喜欢的一句话&#xff1a; 常常会回顾努力的自己&#xff0c;所以要为自己的努力留下足迹 喜欢的话可以点个赞谢谢了。 作者&#xff1…

学习笔记:MySQL数据库操作5

1. 触发器&#xff08;Triggers&#xff09; 触发器是数据库的一种高级功能&#xff0c;它允许在执行特定数据库操作&#xff08;如INSERT、UPDATE、DELETE&#xff09;之前或之后自动执行一段代码。 1.1 创建商品和订单表 商品表&#xff08;goods&#xff09; gid: 商品编号…

navicat15安装破解

下载地址&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/19RlXTArDfNxT5n98A0GbvQ 提取码&#xff1a;qtew 破解教程 1、运行注册机&#xff0c;勾选Backup、Host和Navicat v15&#xff0c;如图所示。然后点击Patch按钮&#xff0c;找到Navicat Premium 15安装路径下的…

什么是模型无关方法?

「AI秘籍」系列课程&#xff1a; 人工智能应用数学基础人工智能Python基础人工智能基础核心知识人工智能BI核心知识人工智能CV核心知识AI 进阶&#xff1a;企业项目实战 可直接在橱窗里购买&#xff0c;或者到文末领取优惠后购买&#xff1a; 可以与任何模型一起使用的所有强…

破局产品同质化:解锁3D交互式营销新纪元!

近年来&#xff0c;随着数字体验经济的蓬勃发展&#xff0c;3D交互式营销作为一种创新手段迅速崛起&#xff0c;它巧妙地解决了传统产品展示中普遍存在的缺乏差异性和互动性的问题&#xff0c;使您的产品在激烈的市场竞争中独树一帜&#xff0c;脱颖而出。 若您正面临产品营销…

抖音直播弹幕数据逆向:websocket和JS注入

&#x1f50d; 思路与步骤详解 &#x1f575;️‍♂️ 思路介绍 首先&#xff0c;我们通过抓包工具进入的直播间&#xff0c;捕获其网络通信数据&#xff0c;重点关注WebSocket连接。发现直播弹幕数据通过WebSocket传输&#xff0c;这种方式比传统的HTTP更适合实时数据的传输。…

昇思25天学习打卡营第24天 | Pix2Pix实现图像转换

昇思25天学习打卡营第24天 | Pix2Pix实现图像转换 文章目录 昇思25天学习打卡营第24天 | Pix2Pix实现图像转换Pix2Pix模型cGANCGAN的损失函数 数据网络构建生成器判别器Pix2Pix网络 总结打卡 Pix2Pix模型 Pix2Pix是基于条件生成对抗网络&#xff08;cGAN, Condition Generativ…

如何在测试中保护用户隐私!

在当今数据驱动的时代&#xff0c;用户隐私保护成为了企业和开发团队关注的焦点。在软件测试过程中&#xff0c;处理真实用户数据时保护隐私尤为重要。本文将介绍如何在测试中保护用户隐私&#xff0c;并提供具体的方案和实战演练。 用户隐私保护的重要性 用户隐私保护不仅是法…

Qt自定义带前后缀图标的PushButton

写在前面 Qt提供QPushButton不满足带前后缀图标的需求&#xff0c;因此考虑自定义实现带前后缀图标的PushButton&#xff0c;方便后续快速使用。 效果如下&#xff1a; 同时可设置前后缀图标和文本之间间隙&#xff1a; 代码实现 通过前文介绍的Qt样式表底层实现 可以得…

linux ftp操作记录

一.ftp 创建用户 passwd: user ftpuser does not exist 如果你遇到 passwd: user ftpuser does not exist 的错误&#xff0c;这意味着系统中不存在名为 ftpuser 的用户。你需要首先确认FTP用户是否是系统用户&#xff0c;还是FTP服务器软件&#xff08;如Pure-FTPd&#xff…

类和对象:完结

1.再深构造函数 • 之前我们实现构造函数时&#xff0c;初始化成员变量主要使⽤函数体内赋值&#xff0c;构造函数初始化还有⼀种⽅ 式&#xff0c;就是初始化列表&#xff0c;初始化列表的使⽤⽅式是以⼀个冒号开始&#xff0c;接着是⼀个以逗号分隔的数据成 员列表&#xf…

redis的使用场景

1. redis的使用场景 redis使用场景的案例&#xff1a;[1]热点数据的缓存[2]分布式锁[3]短信业务&#xff08;登录注册时&#xff09;2. redis实现注册登录功能 代码 在发送验证码时&#xff0c;先判断数据库是否有该手机号&#xff0c;有则发送验证码&#xff08;此时redis缓存…

基于微信小程序+SpringBoot+Vue的自习室选座与门禁系统(带1w+文档)

基于微信小程序SpringBootVue的自习室选座与门禁系统(带1w文档) 基于微信小程序SpringBootVue的自习室选座与门禁系统(带1w文档) 本课题研究的研学自习室选座与门禁系统让用户在小程序端查看座位&#xff0c;预定座位&#xff0c;支付座位价格&#xff0c;该系统让用户预定座位…

Jmeter三种方式获取数组中多个数据并将其当做下个接口参数入参【附带JSON提取器和CSV格式化】

目录 一、传统方式-JOSN提取器获取接口返回值 1、接口调用获取返回值 2、添加JSON提取器 3、调试程序查看结果 4、添加循环控制器 5、设置count计数器 6、添加请求 7、执行请求 二、CSV参数化 1、将结果写入后置处理程序 2、设置循环处理器 3、添加CSV文件 4、设置…

【机器学习】用Jupyter Notebook实现并探索单变量线性回归的代价函数以及遇到的一些问题

引言 在机器学习中&#xff0c;代价函数&#xff08;Cost Function&#xff09;是一个用于衡量模型预测值与实际值之间差异的函数。在监督学习中&#xff0c;代价函数是评估模型性能的关键工具&#xff0c;它可以帮助我们了解模型在训练数据上的表现&#xff0c;并通过优化过程…

IPD推行成功的核心要素(十五)项目管理提升IPD相关项目交付效率和用户体验

研发项目往往包含很多复杂的流程和具体的细节。因此&#xff0c;一套完整且标准的研发项目管理制度和流程对项目的推进至关重要。研发项目管理是成功推动创新和技术发展的关键因素。然而在实际管理中&#xff0c;研发项目管理常常面临着需求不确定、技术风险、人员素质、成本和…

PyTorch安装CUDA标准流程(可解决大部分GPU无法使用问题)

最近一段时间在研究PyTorch中的GPU的使用方法&#xff0c;之前曾经安装过CUDA&#xff0c;不过在PyTorch中调用CUDA时无法使用。考虑到是版本不兼容问题&#xff0c;卸载后尝试了其他的版本&#xff0c;依旧没有能解决问题&#xff0c;指导查阅了很多资料后才找到了解决方案。 …

uni-app声生命周期

应用的生命周期函数在App.vue页面 onLaunch:当uni-app初始化完成时触发&#xff08;全局触发一次&#xff09; onShow:当uni-app启动&#xff0c;或从后台进入前台时显示 onHide:当uni-app从前台进入后台 onError:当uni-app报错时触发,异常信息为err 页面的生命周期 onLoad…

数据治理之“财务一张表”

前言 信息技术的发展&#xff0c;伴随企业业务系统的纷纷建设&#xff0c;提升业务处理效率的同时&#xff0c;也将企业的整体主价值链流程分成了一段一段的业务子流程&#xff0c;很多情况下存在数据上报延迟、业务协作不顺畅、计划反馈不及时、库存积压占资多……都可以从数据…