【机器学习】sklearn的集成学习用于图像分类从0到1,注意点和坑点

news2024/11/14 15:29:56

文章目录

  • 前言
  • 1.需求分析
    • 1.1 场景
    • 1.2 解决方案
  • 2. 代码
    • 2.1 提取特征
    • 2.2 构建分类器
    • 2.4 集成模型
    • 2.5 总的训练代码
  • 3.fast api 封装
  • 4.总结


前言

深度学习崛起后,好像机器学习就没落了,但在固定场景下,还是很好用的。下面就是展厅项目的识别任务。老规矩,集成学习的基础知识点少讲,或者不讲,因为这种文章已经很多了。主要是基于场景的业务问题解决。


1.需求分析

1.1 场景

在这里插入图片描述
画框的这几个小人会移动,移动到左边靠近水的地方,要报警。目标检测比如yolov之类的是可以实现的,但现在想学习一下sklearn的集成学习。所以采用了这个方式。

挑战,灯光的变化
在这里插入图片描述

1.2 解决方案

摄像头是不动的,所以用opencv把这部分切割出来做分类。有人和没人两类 0 表示有人, 1表示没人。
在这里插入图片描述

2. 代码

2.1 提取特征

提取图像特征的方式有很多种,这里介绍两张haar和lbp

# 提取Haar特征
def extract_haar_features(images):
    features = []

    for image in images:
        # a = 1
        # feature = cv2.HOGDescriptor().compute(image)
        feature = cv2.HOGDescriptor().compute(image).flatten()
        features.append(feature)

    return np.array(features)

# 提取lbp特征
def extract_lbp_features(images):
    features = []

    for image in images:
        # a = 1
        feature = local_binary_pattern(image, 8, 1, method='uniform').flatten()
        features.append(feature)

    return np.array(features)

2.2 构建分类器

我这里用了决策树、随机森林、KNN、SVM四个分类器,因为是sklearn封装的,调用起来很简单。

# 构建决策树分类器
def build_decision_tree_classifier(X_train, y_train):
    clf = DecisionTreeClassifier(max_depth=5)
    clf.fit(X_train, y_train)
    return clf

# 构建随机森林分类器
def build_random_forest_classifier(X_train, y_train):
    clf = RandomForestClassifier(n_estimators=50, max_depth=3)
    clf.fit(X_train, y_train)
    return clf

# 构建KNN分类器
def build_knn_classifier(X_train, y_train):
    clf = KNeighborsClassifier(n_neighbors=10,algorithm='kd_tree')
    clf.fit(X_train, y_train)
    return clf

# 构建SVM分类器
def build_svm_classifier(X_train, y_train):
    clf = SVC(kernel='linear', C=0.025)
    clf.fit(X_train, y_train)
    return clf

2.4 集成模型

构建集成模型并保存权重,clf最终的预测是由几个分类器投票形成的。

# 构建集成模型
def build_ensemble_model(X_train, y_train):
    clfs = []

    clf1 = build_decision_tree_classifier(X_train, y_train)
    clfs.append(('dt', clf1))

    clf2 = build_random_forest_classifier(X_train, y_train)
    clfs.append(('rf', clf2))

    clf3 = build_knn_classifier(X_train, y_train)
    clfs.append(('knn', clf3))

    clf4 = build_svm_classifier(X_train, y_train)
    clfs.append(('svm', clf4))

    eclf = VotingClassifier(estimators=clfs, voting='hard')
    
    # 输出训练过程
    for clf in [clf1,clf2,clf3,clf4, eclf]:
        clf.fit(X_train, y_train)
        y_pred = clf.predict(X_test)
        acc = accuracy_score(y_test, y_pred)
        print(f'{clf.__class__.__name__} Accuracy: {acc}')

    eclf.fit(X_train, y_train)

    # 保存模型权重
    joblib.dump(eclf, target_weight_path)
    


    return eclf

2.5 总的训练代码

import cv2
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import joblib
from util_help import read_images
import os
import pickle
from skimage.feature import local_binary_pattern

model_name = "people_drown"
weight_name = model_name + ".joblib"
save_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),"weight",model_name)
os.makedirs(save_dir,exist_ok=True)
target_weight_path = os.path.join(save_dir,weight_name)
# 提取Haar特征
def extract_haar_features(images):
    features = []

    for image in images:
        # a = 1
        # feature = cv2.HOGDescriptor().compute(image)
        feature = cv2.HOGDescriptor().compute(image).flatten()
        features.append(feature)

    return np.array(features)

# 提取Haar特征
def extract_lbp_features(images):
    features = []

    for image in images:
        # a = 1
        feature = local_binary_pattern(image, 8, 1, method='uniform').flatten()
        features.append(feature)

    return np.array(features)

# 构建决策树分类器
def build_decision_tree_classifier(X_train, y_train):
    clf = DecisionTreeClassifier(max_depth=5)
    clf.fit(X_train, y_train)
    return clf

# 构建随机森林分类器
def build_random_forest_classifier(X_train, y_train):
    clf = RandomForestClassifier(n_estimators=50, max_depth=3)
    clf.fit(X_train, y_train)
    return clf

# 构建KNN分类器
def build_knn_classifier(X_train, y_train):
    clf = KNeighborsClassifier(n_neighbors=10,algorithm='kd_tree')
    clf.fit(X_train, y_train)
    return clf

# 构建SVM分类器
def build_svm_classifier(X_train, y_train):
    clf = SVC(kernel='linear', C=0.025)
    clf.fit(X_train, y_train)
    return clf

# 构建集成模型
def build_ensemble_model(X_train, y_train):
    clfs = []

    clf1 = build_decision_tree_classifier(X_train, y_train)
    clfs.append(('dt', clf1))

    clf2 = build_random_forest_classifier(X_train, y_train)
    clfs.append(('rf', clf2))

    clf3 = build_knn_classifier(X_train, y_train)
    clfs.append(('knn', clf3))

    clf4 = build_svm_classifier(X_train, y_train)
    clfs.append(('svm', clf4))

    eclf = VotingClassifier(estimators=clfs, voting='hard')
    
    # 输出训练过程
    for clf in [clf1,clf2,clf3,clf4, eclf]:
        clf.fit(X_train, y_train)
        y_pred = clf.predict(X_test)
        acc = accuracy_score(y_test, y_pred)
        print(f'{clf.__class__.__name__} Accuracy: {acc}')

    eclf.fit(X_train, y_train)

    # 保存模型权重
    joblib.dump(eclf, target_weight_path)
    


    return eclf

# 对测试集进行预测
def predict(model, test_features):
    preds = model.predict(test_features)
    return preds

# 评估模型准确率
def evaluate(y_true, y_pred):
    acc = accuracy_score(y_true, y_pred)
    print(f'Accuracy: {acc}')

# 加载模型权重

clf = joblib.load(target_weight_path) if os.path.exists(target_weight_path)  else None

image_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),"data","handle",model_name)
# 读取数据集
images, labels = read_images(image_dir)
print(len(images))
# 提取特征
# features = extract_haar_features(images)
features = extract_lbp_features(images)

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.25, random_state=42)

# 构建或加载集成模型
if clf is None:
    clf = build_ensemble_model(X_train, y_train)

# 对测试集进行预测
y_pred = predict(clf, X_test)

# 评估模型性能
evaluate(y_test, y_pred)

3.fast api 封装

待续,正在开发中。

4.总结

对于光的影响,转成灰度图是否有影响–这是一个大的话题?待续。
另外,用Haar提取特征的时候,权重文件会恐怖的达到12G,而且训练的时候内存会爆掉,这个问题耽误我小半天时间,原因未知。
后来改成lbp的方式就很小的,速度也很快。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/494611.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

解决Edge Dev更新后NewBing侧边栏消失的问题,并使用NewBing作画

文章目录 解决Edge Dev更新后NewBing侧边栏消失的问题,并使用NewBing作画问题来源操作步骤打开侧边栏步骤尝试让NewBing给出图像输出表情包或者其他图片使用NewBing作画 查看聊天记录插件 总结 解决Edge Dev更新后NewBing侧边栏消失的问题,并使用NewBing…

标签制作软件如何批量制作DotCode码

DotCode码是由不连续的点组成的二维条形码符号。设计的初衷是工业流水线上使用高速喷墨/激光打印机印刷产品有效期、批号以及序列号等。其尺寸是灵活可变的,可以根据货品表面的大小来调整印刷。下面带大家一起看一下在标签制作软件中如何批量制作: 打开…

STM32F4_十进制和BCD码的转换

目录 前言 1. BCD码 2. BCD码和十进制转换的算法 前言 最近在学习STM32单片机(不仅仅是32)的RTC实时时钟系统的过程中,需要配置时钟的时间、日期;这些都需要实现BCD码和十进制之间进行转换。这里和大家一起学习BCD码和十进制之…

C++函数必备简单知识

目录 1、函数的定义与声明 (1)定义 (2)声明 2、指针传参 3、引用 4、函数的引用传参 5、函数重载 overlord (1)参数数量不同 (2)参数类型不同 6、避免overlord歧义 7、内…

Opencv+Python图像像素处理

目录 二值图像的像素访问、修改 单个像素访问、修改 多个像素修改 彩色图像(三维数组) 像素访问、修改 BGR模式 像素访问、修改 二值图像的像素访问、修改 单个像素访问、修改 import numpy as np import cv2 as cv # 使用Numpy库中的函数zeros()可…

springboot登录验证

案例-登录认证 已经实现了部门管理、员工管理的基本功能,但是大家会发现,但没有登录,就直接访问到了Tlias智能学习辅助系统的后台。 这是不安全的,今天的主题就是登录认证。 最终要实现的效果就是用户必须登录之后,才…

Spark学习笔记【shuffle】

本文基本上是大数据处理框架Apache Spark设计与实现的Shuffle部分的学习。以及Spark基础知识Bambrow Shuffle解决啥问题 上游和下游,不同stage,不同的task之间是如何传递数据的。ShuffleManager管理ShuffleWrite和ShuffleRead 分为两个阶段&#xff1…

基于JavaWeb实现的寻码网文章资讯管理系统

一、技术结构 前端:html ajax 后端:SpringBootMybatis-plus 环境:JDK1.8 | Mysql | Maven | Redis 二、功能简介 数据库与代码截图 后端管理-登录页 后端管理-首页 后端管理-文章管理-发布文章 后端管理-文章管理-文章列表 后端管理-文…

Vue快速入门,常用指令,生命周期

Vue常用指令 案例&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"…

MySQL一次大量内存消耗的跟踪

GreatSQL社区原创内容未经授权不得随意使用&#xff0c;转载请联系小编并注明来源。GreatSQL是MySQL的国产分支版本&#xff0c;使用上与MySQL一致。文章来源&#xff1a;GreatSQL社区原创 线上使用MySQL8.0.25的数据库&#xff0c;通过监控发现数据库在查询一个视图(80张表的u…

【网络进阶】WebSocket协议

文章目录 1. Web实时技术的应用2. WebSocket协议介绍2.1 WebSocket的工作原理2.2 优点2.3. 使用场景2.4 实现细节 3. WebSocket服务器实现3.1 客户端代码&#xff08;HTML & JavaScript&#xff09;3.2 服务器端代码&#xff08;C&#xff09;3.3 测试结果 1. Web实时技术的…

Qt 智能指针介绍: QSharedPointer、QWeakPointer 、QScopedPointer 、QPointer(附实例)

文章目录 1. 概述2. Qt 中有几种智能指针&#xff1f;2.1 QSharedPointer 实例2.2 QSharedPointer 与 QWeakPointer 实例2.3 QScopedPointer 实例2.4 QPointer 实例 1. 概述 在使用动态内存分配的情况下&#xff0c;需要确保对象的所有权正确地被管理和转移。使用智能指针可以…

【HarmonyOS】【FAQ】HarmonyOS应用开发相关问题解答(一)

【前提简介】 本文档主要总结HarmonyOS开发过程中可能遇到的一些问题解答&#xff0c;主要围绕HarmonyOS展开&#xff0c;包括但不限于不同API版本HarmonyOS开发、UI组件、DevEco Studio、Gitee示例代码等&#xff0c;并将持续更新哦。 【官方FAQ】 【FAQ】HarmonyOS应用及服…

(十二)地理数据库创建——基本组成项及数据加载

地理数据库创建——基本组成项及数据加载 目录 地理数据库创建——基本组成项及数据加载 1.建立数据库中的基本组成项1.1建立要素数据集1.2建立要素类1.2.1建立简单要素类1.2.2建立关系表 1.3建立关系表 2.向地理数据库加载数据2.1导入数据2.1.1导入Shapefile2.1.2导入dBASE 表…

数据结构:顺序表的增、删,查、改实现

1.概念 顺序表是用一段 物理地址连续 的存储单元依次存储数据元素的线性结构&#xff0c;一般情况下采用数组存 储。在数组上完成数据的增删查改。 2.分类 顺序表一般可以分为&#xff1a; 2.1 静态顺序表&#xff1a;使用定长数组存储元素 这样会造成空间给多了浪费&#x…

ThreadLocal初探

一、ThreadLocal介绍 一、官方介绍 从Java官方文档中的描述&#xff1a;ThreadLocal类用来提供线程内部的局部变量&#xff0c;这种变量在多线程环境下访问&#xff08;通过get和set方法访问&#xff09;时&#xff0c;能够保证各个线程的变量相对独立于其他线程内的变量。Thr…

apple pencil必须要买吗?性价比平替电容笔排行榜

要知道&#xff0c;真正的苹果原装Pencil&#xff0c;价格实在是太贵了&#xff0c;普通的消费者根本买不起。所以&#xff0c;有没有可能出现一种平替的、功能一模一样的、与苹果Pencil一样的电容笔呢&#xff1f;这倒也是。国产的平替笔和苹果Pencil的笔相比&#xff0c;并没…

Wireless-Sensor-Network-master_WSN_无线传感网络(Matlab代码实现)

目录 &#x1f4a5;1 概述 &#x1f4da;2 运行结果 &#x1f389;3 参考文献 &#x1f468;‍&#x1f4bb;4 Matlab代码 &#x1f4a5;1 概述 近年来&#xff0c;随着对等网络、云计算和网格计算等分布式环境的发展&#xff0c;无线传感器网络&#xff08;WSN&#xff0…

10分钟吃透,python操作mysql数据库的增、删、改、查

大家好&#xff0c;我是csdn的博主&#xff1a;lqj_本人 这是我的个人博客主页&#xff1a; lqj_本人的博客_CSDN博客-微信小程序,前端,python领域博主lqj_本人擅长微信小程序,前端,python,等方面的知识https://blog.csdn.net/lbcyllqj?spm1011.2415.3001.5343哔哩哔哩欢迎关注…

聊聊汽车OTA测试技术方案

汽车OTA已成为时下热门话题&#xff0c;由于OTA的升级可能会带来一定的风险&#xff0c;针对OTA的测试就尤为重要。本文我们主要通过介绍OTA的发展背景、汽车OTA测试的必要性以及汽车OTA测试内容&#xff0c;为大家分享一套成熟的OTA测试方案。 什么是OTA OTA&#xff08;Over-…