深度学习——自己的训练集——训练模型(CNN)

news2024/11/17 1:33:33

训练模型

    • 1.导入必要的库
    • 2.加载类别名称
    • 3.创建标签映射字典
    • 4.加载图像数据和对应的标签
    • 5.构建和编译CNN模型
    • 6.训练模型
    • 7.保存训练好的模型

1.导入必要的库

导入处理数据和训练模型时需要的库
os: 这个模块提供了与操作系统交互的功能,比如文件和目录操作。
cv2: 这是OpenCV库的别名,它是一个强大的计算机视觉库,用于图像和视频处理。
numpy as np: NumPy是一个用于科学计算的库,它提供了高效的数组处理能力,对于图像处理等任务非常有用。
tensorflow as tf: TensorFlow是一个开源的机器学习库,用于构建和训练各种类型的机器学习模型。

import os
import cv2
import numpy as np
import tensorflow as tf

2.加载类别名称

with open(‘99/classes.txt’, ‘r’) as f:

with open(...) as f::这是上下文管理器(context manager),用于自动处理文件资源的打开和关闭。当with语句执行完成后,文件会自动关闭,即使遇到异常也是如此。
'99/classes.txt':这是要打开的文件的路径。
'r':这是文件打开模式,表示以只读方式打开文件。
f:这是上下文管理器创建的文件对象,可以用来读取文件内容。

classes = f.read().splitlines():

f.read():这个方法调用用于读取文件的全部内容,并将结果作为一个字符串返回。
.splitlines():这个方法调用用于将字符串按照行分隔符(通常是换行符\n)分割成一个列表。
classes:这个变量存储了分割后的列表,其中每个元素都是一个从文件中读取的标签名称。

with open('99/classes.txt', 'r') as f:
    classes = f.read().splitlines()

3.创建标签映射字典

创建了一个标签映射字典,用于将标签索引转换为实际的标签名称。

label_mapping = {
    '0': 'sad',
    '1': 'happy',
    '2': 'amazed',
    '3': 'anger'
}

4.加载图像数据和对应的标签

从文件夹中加载了图像数据和对应的标签。

image_folder = '561'
label_folder = '99'

X_train = []
y_train = []

#遍历image_folder文件夹中的所有文件
for image_file in os.listdir(image_folder):
#创建一个完整的文件路径,将image_folder目录的路径和image_file(文件或子目录的名称)连接起来。
image_path = os.path.join(image_folder, image_file)

#cv2.imread(image_path):这个函数调用用于读取图像文件。
image = cv2.imread(image_path)

#如果图像成功加载,将图像数据添加到X_train列表中。
    if image is not None:
        X_train.append(image)

	#将label_folder目录的路径和image_file(去除.jpg扩展名后的文件名)连接起来,并在最后加上.txt扩展名
        label_file = os.path.join(label_folder, image_file.replace('.jpg', '.txt'))
        with open(label_file, 'r') as f:
            label_index = f.readline().strip().split()[0]  # 只取第一个数字作为标签索引
            label_name = label_mapping[label_index]
            label = classes.index(label_name)
            y_train.append(label)

X_train = np.array(X_train)
y_train = np.array(y_train)

5.构建和编译CNN模型

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(image.shape)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(len(classes), activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

model = tf.keras.Sequential([ … ]):

tf.keras.Sequential:这是一个Keras模型,用于创建一个包含顺序堆叠的层的模型。
[ ... ]:这是一个列表,其中包含了模型中的层。
model:这个变量存储了创建的Keras Sequential模型。

tf.keras.layers.Conv2D(32, (3, 3), activation=‘relu’,
input_shape=(image.shape)):

tf.keras.layers.Conv2D:这是一个2D卷积层,用于提取图像的局部特征。
32:这是卷积层的输出通道数。
(3, 3):这是卷积核的大小,即每个卷积核覆盖的像素区域。
activation='relu':这是激活函数,用于在每个卷积层之后应用。
input_shape=(image.shape):这是输入层的形状,它是从image.shape获得的,确保模型的输入形状与图像数据的形状匹配。

tf.keras.layers.MaxPooling2D((2, 2)):

tf.keras.layers.MaxPooling2D:这是一个2D最大池化层,用于通过取每个池化区域的最大值来减小特征图的大小。
(2, 2):这是池化窗口的大小,即每个池化操作覆盖的像素区域。

tf.keras.layers.Flatten():

tf.keras.layers.Flatten:这是一个扁平化层,用于将2D或多维数组展平为一维数组。

tf.keras.layers.Dense(128, activation=‘relu’):

tf.keras.layers.Dense:这是一个全连接层,用于在模型中添加更多的非线性变换。
128:这是全连接层的神经元数量。
activation='relu':这是激活函数,用于在每个全连接层之后应用。

tf.keras.layers.Dense(len(classes), activation=‘softmax’):

tf.keras.layers.Dense:这是一个全连接层,用于在模型中添加更多的非线性变换。
len(classes):这是全连接层的神经元数量,它等于类别的数量。
activation='softmax':这是激活函数,用于在每个全连接层之后应用,以产生一个概率分布。

model.compile(optimizer=‘adam’,
loss=‘sparse_categorical_crossentropy’, metrics=[‘accuracy’]):

model.compile:这个方法用于编译模型,指定训练过程中使用的优化器、损失函数和评估指标。
optimizer='adam':这是模型使用的优化器,用于调整模型的权重以最小化损失函数。
loss='sparse_categorical_crossentropy':这是模型使用的损失函数,用于评估模型在训练数据上的性能。
metrics=['accuracy']:这是模型使用的评估指标,用于评估模型在训练数据上的性能。

6.训练模型

model.fit(X_train, y_train, epochs=20, batch_size=32)

model.fit:这是Keras中的一个方法,用于训练模型。
X_train:这是模型的输入数据,它是一个NumPy数组。
y_train:这是模型的目标数据,它是一个NumPy数组。
epochs=20:这是训练过程中重复训练数据的次数。
batch_size=32:这是每次梯度更新的样本数量。

7.保存训练好的模型

model.save('cnn_model.h5')

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1700503.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何选择优质的气膜体育馆工程服务商—轻空间

随着现代生活的便利化和时代感的增强,气膜体育馆成为越来越多人的选择。这种美观实用的建筑在学校、社区和体育中心等地广泛应用。许多投资者和客户都有意建造气膜体育馆,但在选择工程服务商时,往往面临困惑。以下几点将帮助您做出明智的选择…

grafana大盘展示node_expod节点

node_expod添加lables标签 Prometheus查询 语句查询 node_exporter_build_infografna添加变量查询 正常有值 切换其他的是有值的 我的报错原因 因为有多个数据源,我选择错了,因为修改的lable标签是其他数据源,所以获取不到 查询语句 我的变量是 $app node_filesyste…

【优选算法】位运算 {位运算符及其优先级;位运算的应用:判断位,打开位,关闭位,转置位,位图,get lowbit,close lowbit;相关编程题解析}

一、位运算符及其优先级 我们知道,计算机中的数在内存中都是以二进制形式进行存储的 ,而位运算就是直接对整数在内存中的二进制位进行操作,因此其执行效率非常高,在程序中尽量使用位运算进行操作,这会大大提高程序的性…

SpringBoot 上传文件示例

示例效果&#xff1a; 前端代码&#xff1a; <html> <head><title>上传文件示例</title></head> <body> <h2>方式一&#xff1a;普通表单上传</h2> <form action"/admin/upload" method"post" enctyp…

【Python】 Django 框架如何支持百万级日访问量

基本原理 Django 是一个高级的 Python Web 框架&#xff0c;它鼓励快速开发和干净、实用的设计。Django 遵循 MVC&#xff08;模型-视图-控制器&#xff09;设计模式&#xff0c;允许开发者通过编写更少的代码来构建高质量的 Web 应用程序。Django 自带了许多内置功能&#xf…

Python-opencv通过距离变换提取图像骨骼

文章目录 距离变换distanceTransform函数 距离变换 如果把二值图像理解成地形&#xff0c;黑色表示海洋&#xff0c;白色表示陆地&#xff0c;那么陆地上任意一点&#xff0c;到海洋都有一个最近的距离&#xff0c;如下图所示&#xff0c;对于左侧二值图像来说&#xff0c;【d…

嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻嘻

欢迎关注博主 Mindtechnist 或加入【Linux C/C/Python社区】一起学习和分享Linux、C、C、Python、Matlab&#xff0c;机器人运动控制、多机器人协作&#xff0c;智能优化算法&#xff0c;滤波估计、多传感器信息融合&#xff0c;机器学习&#xff0c;人工智能等相关领域的知识和…

golang中的md5、sha256数据加密文件md5/sha256值计算步骤和运行内存图解

在go语言中对数据计算一个md5&#xff0c;或sha256和其他语言 如java, php中的使用方式稍有不同&#xff0c; 那就是要加密的数据必须通过流的形式写入到你创建的Hash对象中。 Hash数据加密步骤 1. 先使用对应的加密算法包中的New函数创建一个Hash对象&#xff0c;(这个也就是…

linux经典定时任务

在使用时记得替换为自己的脚本路径。请在相应的脚本第一行加上#!/bin/bash&#xff0c;否则脚本在定时任务中无法执行。 1、在每天凌晨2点执行 0 2 * * * /bin/sh bashup.sh 2、每天执行两次 下面的示例命令将在每天上午5点和下午5点执行。您可以通过逗号分隔指定多个时间戳…

vue中的$nextTick和过渡与动画

一.vue中的$nextTick 简述与用法&#xff1a;这是一个生命周期钩子 1.语法&#xff1a;this.$nextTick(回调函数) 2.作用&#xff1a;在下一次DOM更新结束后执行其指定的回调 3.什么时候用&#xff1a;当修改数据后&#xff0c;要基于更新后的新dom进行某些操作时&#xff0c;…

vue.js基础组件4--下

1.动态组件 1.定义动态组件 利用动态组件可以动态切换页面中显示的组件。使用<component>标签可以定义动态组件&#xff0c;语法格式如下。 <component is"要渲染的组件"></component>上述语法格式中&#xff0c;<component>标签必须配合i…

爬山算法全解析:掌握优化技巧,攀登技术高峰!

一、引言 爬山算法是一种局部搜索算法&#xff0c;它基于当前解的邻域中进行搜索&#xff0c;通过比较当前解与邻域解的优劣来更新当前解&#xff0c;从而逐步逼近最优解。本文将对爬山算法进行详细的介绍。 二、爬山算法简介 爬山算法是一种基于贪心策略的优化算法&#xff…

Modular military character

角色具有31个模块化骨架网格,每个模块具有多个蒙皮: 3个头(4skins) 3件衬衫(9skins) 3条裤子(9skins) 3只靴子(9skins) 7件战术背心(3skins) 4只手和手臂(2skins) 3顶帽子和头盔(9skins) 2个背包(3skins) 3支步枪(3skins) 模块允许您组装超过200万个不同的…

SW手势定义

crtle:独立; T:测量;R隐藏;视图>用户界面>动态显示父子关系 crtld:相同零件; alte:草图显示; altw:基准面显示; ALTZ:上一视图;

Docker笔记-一种在非交互式方式中加载环境变量的方法

背景 我遇到的现象是这样的&#xff0c;我在docker安装了dm python的客户端&#xff0c;但dm python实际上是对libdmdpi.so的调用&#xff0c;在交互式环境中&#xff08;/bin/bash&#xff09;调用python 连接达梦是没有任何问题的&#xff0c;但在非交互环境直接调用&#x…

Mac配置node环境

1.下载nvm(node版本管理工具&#xff0c;同Anaconda对Python的关系&#xff09;。 curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.7/install.sh | bash 2.配置vi ~/.zshrc文件&#xff0c;添加如下配置&#xff1a; export NVM_DIR"$HOME/.nvm" [ -…

软考 系统架构设计师系列知识点之SOME/IP与DDS(1)

本文内容参考&#xff1a; 车载以太网 - SOME/IP简介_someip-CSDN博客 https://zhuanlan.zhihu.com/p/369422441 什么是SOME/IP?_someip-CSDN博客 SOME/IP 详解系列&#xff08;1&#xff09;—— 概述_some ip-CSDN博客 深入浅出SOME/IP协议&#xff1a;基本概念和原理-…

酷开科技相伴童年 | 酷开系统六一特辑:亲子共赏,启迪成长

六一儿童节&#xff0c;属于每个茁壮成长的孩子&#xff0c;也属于每个童心未泯的“少年”。《小王子》里说&#xff0c;使生活如此美丽的是我们藏起来的真诚和童心。马上就到六一儿童节了&#xff0c;就让我们用温柔而富有童真的笔触&#xff0c;唤醒那份沉睡已久的童心吧。 在…

gitlab没有合并权限、发起合并请求

1.选择项目 2.发起合并请求 3.选择需要合并的分支&#xff0c;第一个是被合并的分支&#xff0c;第二个是合并上哪个分支 4.选择有合并权限的管理人 5.发起请求&#xff0c;通知、等待管理人合并

【Docker实战】进入四大数据库的命令行模式

上一篇我们讲了docker exec命令&#xff0c;这一次我们使用docker exec命令来进入四大数据库的命令行模式。 我们进行游戏开发或软件开发是离不开四大数据库的&#xff0c;这四大数据库分别是关系型数据库mysql、postgres&#xff0c;nosql数据库redis、mongodb。将它们容器化…