深度学习——图像分类(CNN)—训练模型

news2024/10/7 20:36:11

训练模型

    • 1.导入必要的库
    • 2.定义超参数
    • 3.读取训练和测试标签CSV文件
    • 4.确保标签是字符串类型
    • 5.显示两个数据框的前几行以了解它们的结构
    • 6.定义图像处理参数
    • 7.创建图像数据生成器
    • 8.设置目录路径
    • 9.创建训练和验证数据生成器
    • 10.构建模型
    • 11.编译模型
    • 12.训练模型并收集历史
    • 13.绘制损失和准确率曲线
    • 14.保存图表
    • 15.保存模型到本地

1.导入必要的库

pandas as pd: Pandas是一个强大的数据分析和处理库,它提供了数据结构(如DataFrame)和工具,用于数据操作和分析。
tensorflow.keras.preprocessing.image import ImageDataGenerator: ImageDataGenerator是Keras的一部分,它用于图像数据的预处理和增强,例如,随机裁剪、旋转、缩放等。
tensorflow.keras.models import Sequential: Sequential模型是Keras中的一种模型,它允许您顺序地堆叠层。
tensorflow.keras.layers: 包含了Keras中所有的层类型,如Conv2D、MaxPooling2D、Flatten、Dense等。
tensorflow.keras.optimizers: 包含了Keras中所有的优化器类型,如Adam、SGD等。
sklearn.model_selection import train_test_split: train_test_split是Scikit-Learn的一部分,它用于将数据集分割为训练集和测试集。
numpy as np: NumPy是一个用于科学计算的库,它提供了高效的数组处理能力,对于图像处理等任务非常有用。
sklearn.preprocessing import LabelBinarizer: LabelBinarizer是Scikit-Learn的一部分,它用于将类别标签转换为二进制数组。
matplotlib.pyplot as plt: Matplotlib是一个绘图库,pyplot是其中的一个模块,它提供了一个类似于MATLAB的绘图框架。
import pickle: pickle是Python的标准库,它用于序列化Python对象,以便将它们保存到文件或从文件中加载。

import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.preprocessing import LabelBinarizer
import matplotlib.pyplot as plt
import pickle

2.定义超参数

INIT_LR = 0.01
EPOCHS = 30
BS = 32

3.读取训练和测试标签CSV文件

train_labels.csv和test_labels.csv在资源中。

# 读取训练标签CSV文件
train_labels_filename = 'train_labels.csv'
train_labels_df = pd.read_csv(train_labels_filename)

# 读取测试标签CSV文件
test_labels_filename = 'test_labels.csv'
test_labels_df = pd.read_csv(test_labels_filename)

4.确保标签是字符串类型

train_labels_df[‘label’] = train_labels_df[‘label’].astype(str):

train_labels_df['label']:这是train_labels_df DataFrame中名为label的列。
.astype(str):这是Pandas中的一个方法,用于将列的数据类型转换为字符串类型。

test_labels_df[‘label’] = test_labels_df[‘label’].astype(str):

test_labels_df['label']:这是test_labels_df DataFrame中名为label的列。
.astype(str):这是Pandas中的一个方法,用于将列的数据类型转换为字符串类型。

train_labels_df['label'] = train_labels_df['label'].astype(str)
test_labels_df['label'] = test_labels_df['label'].astype(str)

5.显示两个数据框的前几行以了解它们的结构

print(train_labels_df.head())
print(test_labels_df.head())

6.定义图像处理参数

img_width:这是一个变量,用于存储图像的宽度。
img_height:这是一个变量,用于存储图像的高度。
= 150, 150:这行代码将img_width和img_height变量分别设置为150。

img_width, img_height = 150, 150

7.创建图像数据生成器

ImageDataGenerator:这是Keras中的一个类,用于创建一个数据生成器,用于图像数据的增强和预处理。
rescale=1./255:这是一个参数,用于将图像的像素值从0到255的范围转换为0到1的范围,这是常见的图像预处理步骤。
validation_split=0.2:这是一个参数,用于指定训练数据中用于验证的比例。在这里,20%的数据将用于验证,80%的数据将用于训练。
data_gen:这是生成的ImageDataGenerator对象,它将在后续的训练过程中用于生成增强的图像数据。

data_gen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

8.设置目录路径

train和test压缩文件在资源中

# 并且数据集应该存储在环境可访问的路径中
train_dir = 'D:/rgzn/face/DATASET/train'  # 包含子文件夹的父目录
test_dir = 'D:/rgzn/face/DATASET/test'    # 包含子文件夹的父目录

9.创建训练和验证数据生成器

#flow_from_dataframe:这是Keras中的一个方法,用于创建一个数据生成器,它可以从DataFrame中加载图像和标签。
train_data_gen = data_gen.flow_from_dataframe(

#要加载的数据源
dataframe=train_labels_df,
#包含图像文件的目录
directory=train_dir,  
#DataFrame中包含图像路径的列名。
x_col='image',
#DataFrame中包含标签的列名。
y_col='label',
#目标图像的大小
target_size=(img_width, img_height),
#每次迭代中从数据生成器中获取的样本数量。
batch_size=32,
#随机种子,用于确保每次运行时生成相同的数据增强
seed=42,
#数据集的子集,用于训练。
    subset='training',
)
validation_data_gen = data_gen.flow_from_dataframe(
    dataframe=test_labels_df,
    directory=test_dir,  # 包含子文件夹的父目录
    x_col='image',
    y_col='label',
    target_size=(img_width, img_height),
    batch_size=32,
seed=42,
#数据集的子集,用于验证。
    subset='validation',
)

10.构建模型

# 构建模型
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)))
model.add(MaxPooling2D(pool_size=(2, 2)))

# 新增的卷积层
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

# 展平层
model.add(Flatten())

# 全连接层
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))

# 输出层
model.add(Dense(7, activation='softmax'))

11.编译模型

model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

model:这是之前创建和配置的Keras模型。
compile:这是Keras中的一个方法,用于编译模型,指定训练过程中使用的损失函数、优化器和评估指标。
loss='categorical_crossentropy':这是模型使用的损失函数,适用于多类分类问题。
optimizer='adam':这是模型使用的优化器,用于调整模型的权重以最小化损失函数。
metrics=['accuracy']:这是模型使用的评估指标,用于评估模型在训练数据上的性能。

12.训练模型并收集历史

history = model.fit(train_data_gen, epochs=EPOCHS, validation_data=validation_data_gen, batch_size=BS)

fit:这是Keras中的一个方法,用于训练模型。
train_data_gen:这是之前创建的训练数据生成器。
epochs=EPOCHS:这是训练过程中重复训练数据的次数。
validation_data=validation_data_gen:这是用于验证模型的数据。
batch_size=BS:这是每次迭代中从数据生成器中获取的样本数量。
history:这是训练过程中记录的性能指标,如损失和准确率。

13.绘制损失和准确率曲线

N = np.arange(0, EPOCHS)
#设置图表的样式
plt.style.use('ggplot')
plt.figure()

plt.plot(N, history.history['loss'], label='train_loss')
plt.plot(N, history.history['val_loss'], label='val_loss')
plt.plot(N, history.history['accuracy'], label='train_acc')
plt.plot(N, history.history['val_accuracy'], label='val_acc')

plt.title("Training Loss And Accuracy (CNN)")
plt.xlabel('Epoch #')
plt.ylabel('Loss/Accuracy')
plt.legend()
plt.axis([0, EPOCHS, 0, 2])

14.保存图表

plt.savefig('plot.png')

15.保存模型到本地

print('[INFO] 正在保存模型')
model.save('model.h5')

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1686688.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

145.栈和队列:删除字符串中的所有相邻重复项(力扣)

题目描述 代码解决 class Solution { public:string removeDuplicates(string s) {// 定义一个栈来存储字符stack<char> st;// 遍历字符串中的每一个字符for(int i 0; i < s.size(); i){// 如果栈为空或栈顶字符与当前字符不相同&#xff0c;则将当前字符入栈if(st.e…

SpringBoot项目中redis序列化和反序列化LocalDateTime失败

实体类中包含了LocalDateTime 类型的属性&#xff0c;把实体类数据存入Redis后变成这样&#xff1a; 此时&#xff0c;存入redis不会报错&#xff0c;但是从redis获取的时候&#xff0c;会报错&#xff1a; com.fasterxml.jackson.databind.exc.InvalidDefinitionException: Ca…

AI - Transformer架构工作原理

一、概述 Transformer是由Vaswani等人在2017年提出的一种基于自注意力机制&#xff08;Self-Attention Mechanism&#xff09;的深度学习网络架构的大模型&#xff0c;被广泛应用于自然语言处理&#xff08;NLP&#xff09;领域&#xff0c;如机器翻译、文本生成等任务。它摒弃…

LabVIEW高温往复摩擦测试系统中PID控制

在LabVIEW开发高温往复摩擦测试系统中实现PID控制&#xff0c;需要注意以下几个方面&#xff1a; 1. 系统建模与参数确定 物理模型建立: 首先&#xff0c;需要了解被控对象的物理特性&#xff0c;包括热惯性、摩擦系数等。这些特性决定了系统的响应速度和稳定性。实验数据获取…

陕西煤矿化工集团如何投稿刊登到央媒

随着信息技术的飞速发展&#xff0c;国家级媒体平台已经成为了众多作者追求发表文章的热门选择。然而&#xff0c;要想在这些平台上成功发表文章&#xff0c;除了具备优秀的文稿质量外&#xff0c;还需要掌握一定的投稿技巧和策略。本文将为您详细介绍国家级媒体投稿方式&#…

samba_ubuntu_share_vmbox_vmware

_____ Ubuntu 利用 samba 与 win 直接共享文件夹 _____ samba Samba - 维基百科&#xff0c;自由的百科全书 (wikipedia.org) 用于 win 和 unix 直接访问资源 samba 为选定的 unix 目录建立网络共享&#xff0c; 使得 win 用户可以像访问普通 win 下的文件夹那样来通过网络来…

Discourse Discover 添加你的网站到 Discourse 官方

discourse discover 应该允许你把你的 Discourse 实例添加到 Discourse 的 https://discover.discourse.org/ 1 页面中。 直接在你网站的配置上搜索 Discourse Discover &#xff0c;余下的工作就可以交给 Discourse 了。 还没有选的&#xff0c;可以马上选上喔。 但显然排序…

Baidu Comate For Xcode 你的AI编程助手

前言 Baidu Comate 基于文心大模型&#xff0c;结合百度编程大数据&#xff0c;为你生成优质编程代码 你的AI编程助手&#xff0c;你的编码效率提升好帮手 Baidu Comate 释放“十倍”软件生产力 一、Xcode 安装配置 Baidu Comate 安装 已安装Xcode的情况下&#xff0c;下载B…

Windows下安装Hadoop(引导版)

Windows下安装Hadoop(引导版) 本环境只作为测试环境的搭建和学习使用 参考文档 环境&#xff1a; 首先确定环境为java1.8 或者hadoop适配的版本 cmd java -version查看 hadoop环境变量可以不用设置 关于hdfs的配置可以自行修改目录 具体的安装方式参考下面的两个文档 下载…

Spring Boot企业级开发教程-第4章Spring Boot视图技术

文章目录 4.1 Spring Boot支持的视图技术4.2 Thymaleaf基本语法常用标签标准表达式1.变量表达式2.选择变量表达式3.消息表达式4.链接表达式5.片段表达式 4.3 Thymaleaf基本使用4.3.1 Thymeleaf模板基本配置4.3.2 静态资源的访问 4.4 使用Thymaleafs完成页面的数据展示4.5 使用T…

Android面试题之Kotlin常见集合操作技巧

本文首发于公众号“AntDream”&#xff0c;欢迎微信搜索“AntDream”或扫描文章底部二维码关注&#xff0c;和我一起每天进步一点点 list 创建和修改 不可变list,listOf var list listOf("a","d","f") println(list.getOrElse(3){"Unkn…

NebulaGraph

文章目录 关于 NebulaGraph客户端支持安装 NebulaGraph关于 nGQLnGQL 可以做什么2500 条 nGQL 示例原生 nGQL 和 openCypher 的关系 Backup&Restore功能 导入导出导入工具导出工具 NebulaGraph ImporterNebulaGraph ExchangeNebulaGraph Spark ConnectorNebulaGraph Flink …

临时工说:为什么成熟的数据库企业都在云上部署产品,并把主要力量放到云上...

开头还是介绍一下群&#xff0c;如果感兴趣PolarDB ,MongoDB ,MySQL ,PostgreSQL ,Redis, Oceanbase, Sql Server等有问题&#xff0c;有需求都可以加群群内有各大数据库行业大咖&#xff0c;可以解决你的问题。加群请联系 liuaustin3 &#xff0c;&#xff08;共2320人左右 1 …

网站笔记:huggingface——can you run it?

Can You Run It? LLM version - a Hugging Face Space by Vokturz 1 配置设置部分 Model Name就是需要测量的模型名称 GPU Vendor ——GPU供应商 Filter by RAM (按RAM过滤) 筛选出所有内存容量在选择范围之间的GPU GPU 下拉菜单选择具体的GPU型号 LoRa % trainable param…

如何用VSCode debug Python文件

诸神缄默不语-个人CSDN博文目录 需求&#xff1a;我其实一般都用print大法来“调试”程序&#xff0c;但是有时对于机械性比较强但是又有些复杂的程序&#xff0c;还是debug比较方便。 debug功能我之前用过NetBeans和eclipse&#xff0c;应该可以明显看出来我是Java转Python党…

做好智慧校园的顶层设计,助力教育信息化发展

教育信息化已被视为我国教育事业发展的重要支撑。随着国家教育信息化一系列重大工程的部署和实施&#xff0c;我国教育信息化进入快速发展时期&#xff0c;取得了显著成绩。我们认识到国家教育信息化正由初步应用融合阶段向着全面融合创新阶段过度&#xff0c;无论从国家地区的…

上位机图像处理和嵌入式模块部署(mcu之芯片选择)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 目前市面上的mcu很多&#xff0c;有国产的&#xff0c;有进口的&#xff0c;总之种类很多。以stm32为例&#xff0c;这里面又包括了stm32f1、stm32…

local dimming(局部调光)介绍

文章目录 1. 什么是local dimming2. 工作原理3. 类型4. 优点5. 缺点和局限7. 技术发展趋势 1. 什么是local dimming local dimming&#xff08;局部调光&#xff09;是电视和显示器中用于提升画面对比度和画质的背光技术。其基本原理是将背光源&#xff08;通常是LED&#xff…

python写接口性能测试

import time import requestsdef measure_response_time(api_url):try:start_time time.time()response requests.get(api_url, timeout10) # 设置超时时间为10秒end_time time.time()response_time end_time - start_timeprint(f"接口 {api_url} 的响应时间为&#…

UE5 OnlineSubsystem Steam创建会话失败解决方法

连接上Steam但是创建会话失败 解决方法 在DefaultEngine.ini中加上bInitServerOnClienttrue,这个其实在官方文档里用注释给出了&#xff0c;直接取消注释就行 删除项目目录中的Saved、Internmediate、Binaries目录 右键你的项目.uproject选择Generate Visual Studio project f…