基础:用卷积神经网络(CNN)进行猫狗图像分类

news2024/12/22 23:35:43

在本篇教程中,我们将通过卷积神经网络(CNN)实现一个简单的猫狗图像分类器。我们将介绍如何处理数据、构建CNN模型、训练模型并在测试集上进行预测。最终,你将能够用这个模型对未知图像进行猫狗分类。
在这里插入图片描述

1. 环境准备

首先,确保你已经安装了以下库:

  • tensorflow(用于深度学习模型)
  • opencv-python(用于图像处理)
  • numpy(用于数值计算)
  • matplotlib(用于数据可视化)
  • tqdm(用于显示进度条)
pip install tensorflow opencv-python numpy matplotlib tqdm

2. 数据预处理

2.1 数据集结构

假设你已经准备好了猫狗分类的数据集,其中包含两个文件夹:train/test/。每个文件夹下包含多个 .jpg 格式的图像。训练集中的图像每个都对应一个标签,标签通过文件名中的catdog来标识。例如,cat.0.jpg代表一只猫,dog.1.jpg代表一只狗。

2.2 标签编码

为了将标签转换为机器学习模型可以处理的格式,我们使用独热编码(One-Hot Encoding)。具体来说,如果图像是猫,则标签为[1, 0];如果图像是狗,则标签为[0, 1]

def label_img(img):
    word_label = img.split('.')[-3]
    if word_label == 'cat':
        return [1, 0]
    elif word_label == 'dog':
        return [0, 1]

2.3 读取训练数据

我们将读取训练数据,将每个图像调整为固定大小,并转换为灰度图像。然后,将图像和标签组合成一个训练数据集,并进行随机打乱。

def create_train_data():
    training_data = []
    for img in tqdm(os.listdir(train_dir)):
        if not img.endswith('.jpg'):
            continue
        label = label_img(img)
        path = os.path.join(train_dir, img)
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)  # 读取灰度图
        img = cv2.resize(img, (img_size, img_size))  # 调整图像大小
        training_data.append([np.array(img), np.array(label)])
    shuffle(training_data)
    return training_data

2.4 读取测试数据

测试数据与训练数据相似,但是没有标签。我们仅将测试数据中的图像加载并调整为固定大小。

def process_test_data():
    testing_data = []
    for img in tqdm(os.listdir(test_dir)):
        if not img.endswith('.jpg'):
            continue
        path = os.path.join(test_dir, img)
        img_num = img.split('.')[0]
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        img = cv2.resize(img, (img_size, img_size))
        testing_data.append([np.array(img), img_num])
    shuffle(testing_data)
    return testing_data

3. 构建卷积神经网络(CNN)

接下来,我们将构建一个卷积神经网络来进行图像分类。该网络由多个卷积层、池化层和全连接层组成。

3.1 定义模型结构

model = Sequential()

# 输入层和卷积层
model.add(Conv2D(32, (5, 5), activation='relu', input_shape=(img_size, img_size, 1), padding='same'))
model.add(MaxPooling2D(pool_size=(2, 2)))

# 添加更多卷积层和池化层
model.add(Conv2D(64, (5, 5), activation='relu', padding='same'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(128, (5, 5), activation='relu', padding='same'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(64, (5, 5), activation='relu', padding='same'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(32, (5, 5), activation='relu', padding='same'))
model.add(MaxPooling2D(pool_size=(2, 2)))

# 全连接层
model.add(Flatten())
model.add(Dense(1024, activation='relu'))
model.add(Dropout(0.8))

# 输出层
model.add(Dense(2, activation='softmax'))

3.2 编译模型

使用Adam优化器,损失函数采用categorical_crossentropy,因为这是一个多类分类问题。

model.compile(optimizer=Adam(lr), loss='categorical_crossentropy', metrics=['accuracy'])

4. 准备数据并训练模型

4.1 划分训练和验证集

我们将训练数据分为训练集和验证集。

train = train_data[:-500]
test = train_data[-500:]

X = np.array([i[0] for i in train], dtype=np.float64).reshape(-1, img_size, img_size, 1)
y = np.array([i[1] for i in train], dtype=np.float64)
Xtest = np.array([i[0] for i in test], dtype=np.float64).reshape(-1, img_size, img_size, 1)
ytest = np.array([i[1] for i in test], dtype=np.float64)

4.2 训练模型

我们使用fit方法训练模型,并将验证数据传入以监控验证集上的性能。

model.fit(X, y, epochs=3, validation_data=(Xtest, ytest), batch_size=32, verbose=1)

5. 在测试集上进行预测

5.1 加载测试数据并进行预测

我们将加载测试数据并对每个图像进行分类预测。

test_data = process_test_data()

# 可视化预测结果
fig = plt.figure()
for num, data in enumerate(test_data[:16]):
    img_num = data[1]
    img_data = data[0]
    
    y = fig.add_subplot(4, 4, num + 1)
    orig = img_data
    data = img_data.reshape(1, img_size, img_size, 1)
    
    model_out = model.predict(data)[0]
    
    if np.argmax(model_out) == 1:
        label = 'Dog'
    else:
        label = 'Cat'
    
    y.imshow(orig, cmap='gray')
    plt.title(label)
    y.axes.get_xaxis().set_visible(False)
    y.axes.get_yaxis().set_visible(False)

plt.tight_layout()
plt.show()

5.2 可视化输出

使用matplotlib库,我们可以将模型对测试集的预测可视化,直观地查看模型的分类效果。
在这里插入图片描述

6. 总结

通过本教程,你已经学会了如何使用卷积神经网络进行猫狗图像分类。我们涵盖了数据加载、预处理、CNN模型构建、训练及评估,并展示了如何在测试数据上进行预测。

你可以根据自己的需求调整网络结构、优化器和超参数,以获得更好的分类效果。

需要数据集的添加

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2241902.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

node对接ChatGpt的流式输出的配置

node对接ChatGpt的流式输出的配置 首先看一下效果 将数据用流的方式返回给客户端,这种技术需求在传统的管理项目中不多见,但是在媒体或者有实时消息等功能上就会用到,这个知识点对于前端还是很重要的。 即时你不写服务端,但是服务端如果给你这样的接口,你也得知道怎么去使用联…

聊聊Flink:Flink的运行时架构

一、运行时架构 上一篇我们可以看到Flink的核心组件的Deploy层,该层主要涉及了Flink的部署模式,Flink支持多种部署模式:本地、集群(Standalone/YARN)、云(GCE/EC2)。 Local(本地&am…

【动手学电机驱动】 STM32-FOC(7)MCSDK Pilot 上位机控制与调试

STM32-FOC(1)STM32 电机控制的软件开发环境 STM32-FOC(2)STM32 导入和创建项目 STM32-FOC(3)STM32 三路互补 PWM 输出 STM32-FOC(4)IHM03 电机控制套件介绍 STM32-FOC(5&…

华为云前台用户可挂载数据盘和系统盘是怎么做到的?

用户可以选择磁盘类型和容量,其后台是管理员对接存储设备 1.管理员如何在后台对接存储设备(特指业务存储) 1.1FusionSphere CPS(Cloud Provisionivice)云装配服务 它是first node https://10.200.4.159:8890 对接存…

Python爬虫知识体系-----requests-----持续更新

数据科学、数据分析、人工智能必备知识汇总-----Python爬虫-----持续更新:https://blog.csdn.net/grd_java/article/details/140574349 文章目录 一、安装和基本使用1. 安装2. 基本使用3. response常用属性 二、get请求三、post请求四、代理 一、安装和基本使用 1.…

区块链技术在数据安全中的应用

💓 博客主页:瑕疵的CSDN主页 📝 Gitee主页:瑕疵的gitee主页 ⏩ 文章专栏:《热点资讯》 区块链技术在数据安全中的应用 区块链技术在数据安全中的应用 区块链技术在数据安全中的应用 引言 区块链技术基础 1.1 区块链的…

RK3568平台开发系列讲解(GPIO篇)GPIO的sysfs调试手段

🚀返回专栏总目录 文章目录 一、内核配置二、GPIO sysfs节点介绍三、命令行控制GPIO3.1、sd导出GPIO3.2、设置GPIO方向3.3、GPIO输入电平读取3.4、GPIO输出电平设置四、Linux 应用控制GPIO4.1、控制输出4.2、输入检测4.3、使用 GPIO 中断沉淀、分享、成长,让自己和他人都能有…

电商系统开发:Spring Boot框架实战

3 系统分析 当用户确定开发一款程序时,是需要遵循下面的顺序进行工作,概括为:系统分析–>系统设计–>系统开发–>系统测试,无论这个过程是否有变更或者迭代,都是按照这样的顺序开展工作的。系统分析就是分析系…

从电动汽车到车载充电器:LM317LBDR2G 线性稳压器在汽车中的多场景应用

附上LM317系列选型: LM317BD2TG-TO-263 LM317BTG-TO-220 LM317BD2TR4G-TO-263 LM317D2TG-TO-263 LM317D2TR4G-TO-263 LM317TG-TO-220 LM317LBDR2G-SOP-8 LM317LDR2G-SOP-8 LM317MABDTG-TO-252 LM317MABDTRKG-TO-252 LM317MA…

Linux下MySQL的简单使用

Linux下MySQL的简单使用 导语MySQL安装与配置MySQL安装密码设置 MySQL管理命令myisamchkmysql其他 常见操作 C语言访问MYSQL连接例程错误处理使用SQL 总结参考文献 导语 这一章是MySQL的使用,一些常用的MySQL语句属于本科阶段内容,然后是C语言和MySQl之…

前端 JS 实用操作总结

目录 1、重构解构 1、数组解构 2、对象解构 3、...展开 2、箭头函数 1、简写 2、this指向 3、没有arguments 4、普通函数this的指向 3、数组实用方法 1、map和filter 2、find 3、reduce 1、重构解构 1、数组解构 const arr ["唐僧", "孙悟空&quo…

力扣 LeetCode 541. 反转字符串II(Day4:字符串)

解题思路&#xff1a; i可以成段成段的跳&#xff0c;而不是简单的i class Solution {public String reverseStr(String s, int k) {char[] ch s.toCharArray();// 1. 每隔 2k 个字符的前 k 个字符进行反转for (int i 0; i < ch.length; i 2 * k) {// 2. 剩余字符小于 …

鸿蒙版APP-图书购物商城案例

鸿蒙版-小麦图书APP是基于鸿蒙ArkTS-API12环境进行开发&#xff0c;不包含后台管理系统&#xff0c;只有APP端&#xff0c;页面图书数据是从第三方平台(聚合数据)获取进行展示的&#xff0c;包含登录&#xff0c;图书类别切换&#xff0c;图书列表展示&#xff0c;图书详情查看…

卡尔曼滤波:从理论到应用的简介

卡尔曼滤波&#xff08;Kalman Filter&#xff09;是一种递归算法&#xff0c;用于对一系列噪声观测数据进行动态系统状态估计。它广泛应用于导航、控制系统、信号处理、金融预测等多个领域。本文将介绍卡尔曼滤波的基本原理、核心公式和应用案例。 1. 什么是卡尔曼滤波&#x…

学习日志011--模块,迭代器与生成器,正则表达式

一、python模块 在之前学习c语言时&#xff0c;我们学了分文件编辑&#xff0c;那么在python中是否存在类似的编写方式&#xff1f;答案是肯定的。python中同样可以实现分文件编辑。甚至还有更多的好处&#xff1a; ‌提高代码的可维护性‌&#xff1a;当代码被分成多个文件时…

CSS 语法规范

基本语法结构 CSS 的基本语法结构包含 选择器 和 声明块,两者共同组成 规则集。规则集可以为 HTML 元素设置样式,使页面结构和样式实现分离,便于网页的美化和布局调整。 CSS 规则集的结构如下: selector {property: value; }选择器(Selector) 选择器用于指定需要应用…

Bag Graph: Multiple Instance Learning Using Bayesian Graph Neural Networks文献笔记

基本信息 原文链接&#xff1a;[2202.11132] Bag Graph: Multiple Instance Learning using Bayesian Graph Neural Networks 方法概括&#xff1a;用图&#xff08;贝叶斯GNN框架&#xff09;来建模袋之间的相互作用&#xff0c;并使用图神经网络&#xff08;gnn&#xff09…

Spark 共享变量:广播变量与累加器解析

Spark 的介绍与搭建&#xff1a;从理论到实践_spark环境搭建-CSDN博客 Spark 的Standalone集群环境安装与测试-CSDN博客 PySpark 本地开发环境搭建与实践-CSDN博客 Spark 程序开发与提交&#xff1a;本地与集群模式全解析-CSDN博客 Spark on YARN&#xff1a;Spark集群模式…

前海华海金融创新中心的工地餐点探寻

​前海的工地餐大部分都是13元一份的哈。我在前海华海金融创新中心的工地餐点吃过一份猪杂饭&#xff0c;现做13元一份。我一般打包后回公司吃或直接桂湾公园找个环境优美的地方吃饭。 ​我点的这份猪杂汤粉主要是瘦肉、猪肝、肉饼片、豆芽和生菜&#xff0c;老板依旧贴心问需要…

借助Excel实现Word表格快速排序

实例需求&#xff1a;Word中的表格如下图所示&#xff0c;为了强化记忆&#xff0c;希望能够将表格内容随机排序&#xff0c;表格第一列仍然按照顺序编号&#xff0c;即编号不跟随表格行内容调整。 乱序之后的效果如下图所示&#xff08;每次运行代码的结果都不一定相同&#x…