使用 CNN 训练自己的数据集

news2025/1/14 1:03:49

CNN(练习数据集)

  • 1.导包:
  • 2.导入数据集:
  • 3. 使用image_dataset_from_directory()将数据加载tf.data.Dataset中:
  • 4. 查看数据集中的一部分图像,以及它们对应的标签:
  • 5.迭代数据集 train_ds,以便查看第一批图像和标签的形状:
  • 6.使用TensorFlow的ImageDataGenerator类来创建一个数据增强的对象:
  • 7.将数据集缓存到内存中,加快速度:
  • 8. 通过卷积层和池化层提取特征,再通过全连接层进行分类:
  • 9.打印网络结构:
  • 10.设置优化器,定义了训练轮次和批量大小:
  • 11.训练数据集:
  • 12.画出图像:
  • 13.评估您的模型在验证数据集的性能:
  • 14.输出在验证集上的预测结果和真实值的对比:
  • 15.输出可视化报表:

  • 在网上寻找一个新的数据集,自己进行训练

1.导包:

import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.preprocessing import LabelBinarizer
import matplotlib.pyplot as plt
import pickle
import pathlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models

输出结果:
在这里插入图片描述

2.导入数据集:

# 定义超参数
data_dir = "D:\JUANJI"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*.jpg')))
print("图片总数为:", image_count)
batch_size = 30
img_height = 180
img_width = 180

输出结果:
在这里插入图片描述

3. 使用image_dataset_from_directory()将数据加载tf.data.Dataset中:

#  使用image_dataset_from_directory()将数据加载到tf.data.Dataset中
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,  # 验证集0.2
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

输出结果:
在这里插入图片描述

4. 查看数据集中的一部分图像,以及它们对应的标签:

class_names = train_ds.class_names
print(class_names)
# 可视化
plt.figure(figsize=(16, 8))
for images, labels in train_ds.take(1):
    for i in range(16):
        ax = plt.subplot(4, 4, i + 1)
        # plt.imshow(images[i], cmap=plt.cm.binary)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
plt.show()

输出结果:
在这里插入图片描述
在这里插入图片描述

5.迭代数据集 train_ds,以便查看第一批图像和标签的形状:

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

输出结果:
在这里插入图片描述

6.使用TensorFlow的ImageDataGenerator类来创建一个数据增强的对象:

aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1,
            height_shift_range=0.1, shear_range=0.2, zoom_range=0.2,
            horizontal_flip=True, fill_mode="nearest")
x = aug.flow(image_batch, labels_batch)
AUTOTUNE = tf.data.AUTOTUNE

输出结果:
在这里插入图片描述

7.将数据集缓存到内存中,加快速度:

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

输出结果:
在这里插入图片描述

8. 通过卷积层和池化层提取特征,再通过全连接层进行分类:

# 为了增加模型的泛化能力,增加了Dropout层,并将最大池化层更新为平均池化层
num_classes = 3
model = models.Sequential([
    layers.experimental.preprocessing.Rescaling(1./255,input_shape=(img_height,img_width, 3)),
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(256, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(512, activation='relu'),
    layers.Dense(num_classes)
])

输出结果:
在这里插入图片描述

9.打印网络结构:

model.summary()

输出结果:
在这里插入图片描述

10.设置优化器,定义了训练轮次和批量大小:

# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=0.001)

model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

EPOCHS = 100
BS = 5

输出结果:
在这里插入图片描述

11.训练数据集:

# 训练网络
# model.fit 可同时处理训练和即时扩充的增强数据。
# 我们必须将训练数据作为第一个参数传递给生成器。生成器将根据我们先前进行的设置生成批量的增强训练数据。
for images_train, labels_train in train_ds:
    continue
for images_test, labels_test in val_ds:
    continue
history = model.fit(x=aug.flow(images_train,labels_train, batch_size=BS),
                 validation_data=(images_test,labels_test),
steps_per_epoch=1,epochs=EPOCHS)

输出结果:
在这里插入图片描述

12.画出图像:

# 画出训练精确度和损失图
N = np.arange(0, EPOCHS)
plt.style.use("ggplot")
plt.figure()
plt.plot(N, history.history["loss"], label="train_loss")
plt.plot(N, history.history["val_loss"], label="val_loss")
plt.plot(N, history.history["accuracy"], label="train_acc")
plt.plot(N, history.history["val_accuracy"], label="val_acc")
plt.title("Aug Training Loss and Accuracy")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc='upper right')  # legend显示位置
plt.show()

输出结果:
在这里插入图片描述

13.评估您的模型在验证数据集的性能:

test_loss, test_acc = model.evaluate(val_ds, verbose=2)
print(test_loss, test_acc)

输出结果:
在这里插入图片描述

14.输出在验证集上的预测结果和真实值的对比:

#  优化2 输出在验证集上的预测结果和真实值的对比
pre = model.predict(val_ds)
for images, labels in val_ds.take(1):
    for i in range(4):
        ax = plt.subplot(1, 4, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.xticks([])
        plt.yticks([])
        # plt.xlabel('pre: ' + class_names[np.argmax(pre[i])] + ' real: ' + class_names[labels[i]])
        plt.xlabel('pre: ' + class_names[np.argmax(pre[i])])
        print('pre: ' + str(class_names[np.argmax(pre[i])]) + ' real: ' + class_names[labels[i]])
plt.show()

输出结果:
在这里插入图片描述

15.输出可视化报表:

print(labels_test)
print(labels)
print(pre)
print(class_names)
from sklearn.metrics import classification_report
# 优化1 输出可视化报表
print(classification_report(labels_test,
                          pre.argmax(axis=1),
target_names=class_names))

输出结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1713238.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

高维数组到向量的转换:两种方法的深度解析

新书上架~👇全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我👆,收藏下次不迷路┗|`O′|┛ 嗷~~ 目录 一、引言:高维数组的挑战与需求 二、方法一:使用NumPy库进行展平 示…

HTML+CSS 圆形菜单

效果演示 实现了一个圆形菜单的效果,点击菜单按钮后,菜单项会从菜单按钮中心点向外展开,并且菜单项上有文字链接。可以将这段代码的效果称为“圆形菜单展开效果”。 Code <!DOCTYPE html> <html lang="en"><head><meta charset="UTF-8…

word 替换全部字母和数字为新罗马

步骤1&#xff0c;准备好一份测试文档 Adfafdafdafdafdsafdsafasdfdsa 汇总的时光发生的尬的算法的萨法asdfasfsafda大法师短发沙发上对方阿福的萨法的算法大法大方发达舒服打发到沙发上对方说 打发打发打发的负担啊大方阿道夫大法东方大厦发大水Ameti 1. Adafe我直打大噶特区…

Vue开发者工具安装

通过谷歌应用商店安装&#xff08;国外网站&#xff09; 极简插件下载&#xff08;推荐&#xff09;&#xff1a;下载 → 解压 → 点击左上角的三个小点 → 开发者模式 → 拖拽安装 → 插件详情允许访问文件 https://chrome.zzzmh.cn/index 安装步骤&#xff1a; 安装之后可…

集合的综合练习

自动点名器1&#xff1a;班级里有N个学生&#xff0c;实现随机点名器 public class test {public static void main(String [] args) {ArrayList<String> listnew ArrayList<>();//创建一个集合//在集合中添加元素Collections.addAll(list, "李明",&quo…

618必买的数码好物有哪些?盘点兼具设计与实用的数码好物分享

随着618购物节的到来&#xff0c;数码爱好者们又开始跃跃欲试&#xff0c;期待在这个年度大促中寻找到自己心仪的数码好物&#xff0c;在这个数字化时代&#xff0c;数码产品不仅是我们日常生活的必需品&#xff0c;更是提升生活品质的重要工具&#xff0c;那么在众多的数码产品…

一行命令将已克隆的本地Git仓库推送到内网服务器

一、需求背景 我们公司用gitea搭建了一个git服务器&#xff0c;其中支持win7的最高版本是v1.20.6。 我们公司的电脑在任何时候都不能连接外网&#xff0c;但是希望将一些开源的仓库移植到内网的服务器来。一是有相关代码使用的需求&#xff0c;二是可以建设一个内网能够查阅的…

【数据结构和算法】-动态规划爬楼梯

动态规划&#xff08;Dynamic Programming&#xff0c;DP&#xff09;是运筹学的一个分支&#xff0c;主要用于解决包含重叠子问题和最优子结构性质的问题。它的核心思想是将一个复杂的问题分解为若干个子问题&#xff0c;并保存子问题的解&#xff0c;以便在需要时直接利用&am…

15.Redis之持久化

0.知识引入 mysql的事务,有四个比较核心的特性. 1. 原子性 2.一致性 3.持久性 >(和持久化说的是一回事)【把数据存储在硬盘 >持久把数据存储茌内存上>不持久~】【重启进程/重启主机 之后,数据是否存在!!】 4.隔离性~ Redis 是一个 内存 数据库.把数据存储在内存中的…

运维必备的 Linux文件系统

1 前言 我们来简单看一下Linux系统的磁盘、目录、文件。 2 Linux 文件系统 在 Linux 操作系统中&#xff0c;所有被操作系统管理的资源&#xff0c;例如网络接口卡、磁盘驱动器、打印机、输入输出 设备、普通文件或是目录都被看作是一个文件。 也就是说在 Linux 系统中有…

长文总结 | Python基础知识点,建议收藏

测试基础-Python篇 基础① 变量名命名规则 - 遵循PEP8原则 普通变量&#xff1a;max_value 全局变量&#xff1a;MAX_VALUE 内部变量&#xff1a;_local_var 和关键字重名&#xff1a;class_ 函数名&#xff1a;bar_function 类名&#xff1a;FooClass 布尔类型的变量名…

21天精通FL Studio21.2.8!中文汉化全攻略方法教程

在音乐制作的世界中&#xff0c;有一款软件以其强大的功能和易用性而广受好评&#xff0c;那就是FL Studio。而最新版本的FL Studio 21更是在原有的基础上进行了全面的升级&#xff0c;为我们带来了更多的惊喜。今天&#xff0c;我们就一起来了解一下这款最新的水果软件——FL …

全球首例光伏电场网络攻击事件曝光

快速增长的光伏发电正面临日益严重的网络安全威胁。近日&#xff0c;日媒报道了首个针对光伏电场的网络攻击事件。 首例公开确认的光伏电网攻击 日本媒体《产经新闻》近日报道&#xff0c;黑客劫持了一个大型光伏电网中的800台远程监控设备(由工控电子制造商Contec生产的Solar…

超分论文走读

codeFormer 原始动机 高度不确定性&#xff0c;模糊到高清&#xff0c;存在一对多的映射纹理细节丢失人脸身份信息丢失 模型实现 训练VQGAN 从而得到HQ码本空间作为本文的离散人脸先验。为了降低LQ-HQ映射之间的不确定性&#xff0c;我们设计尽量小的码本空间和尽量短的Code…

文心智能体:基于零代码平台的智能体开发与应用

文章目录 初识文心智能体文心智能体平台优势文心智能体平台功能 创建文心智能体总结 初识文心智能体 文心智能体平台是基于文心大模型的智能体构建平台&#xff0c;为开发者提供低成本的开发方式&#xff0c;支持广大开发者根据自身行业领域、应用场景&#xff0c;采用多样化的…

isscc2024 short course4 In-memory Computing Architectures

新兴的ML加速器方法&#xff1a;内存计算架构 1. 概述 内存计算&#xff08;In-memory Computing&#xff09;架构是一种新兴的机器学习加速器方法&#xff0c;通过将计算能力集成到存储器中&#xff0c;以减少数据移动的延迟和能耗&#xff0c;从而提高计算效率和性能。这种方…

用于癌症免疫治疗的自佐剂聚胍纳米疫苗

近期&#xff0c;沈阳药科大学孙进教授团队、罗聪教授团队与新加坡国立大学陈小元教授团队共同合作在美国化学会旗下期刊《ACS nano》&#xff08;IF17.1&#xff09;上发表题为“Self-Adjuvanting Polyguanidine Nanovaccines for Cancer Immunotherapy”&#xff08;用于癌症…

Sora,开启通往世界模拟之路!

2024年2月16日&#xff0c;OpenAI发布视频生成AI大模型Sora。消息一经发出&#xff0c;业界再一次被之震撼。 OpenAI官网描述&#xff1a;Sora是一个根据文本指令生成真实与虚拟场景的AI模型&#xff0c;可根据用户指令生成时长达1分钟的高清视频&#xff0c;能生成具有多个角色…

ee trade:主力如何建仓吸筹的

主力建仓吸筹是指大型机构投资者或市场主力在股票市场中通过一系列策略和操作&#xff0c;逐步购买并积累大量股票&#xff0c;以建立或增加其在某只股票上的持仓。这个过程通常是为了在未来通过股价上涨来实现盈利。以下是一些主力可能采用的建仓吸筹策略&#xff1a; 隐蔽吸…

命运方舟 失落的方舟台服下载教程+账号注册教程(图文全攻略)

命运方舟 失落的方舟台服下载教程账号注册教程(图文全攻略) 失落的方舟&#xff0c;作为今年一款备受瞩目的MMORPG类型游戏&#xff0c;在官宣的时候就收获了一波不小的热度。这款游戏由游戏开发商Smile gate开发&#xff0c;游戏本体搭建于知名的虚幻引擎之上&#xff0c;为玩…