第39周:猫狗识别 2(Tensorflow实战第九周)

news2025/2/20 20:20:16

目录

前言

一、前期工作

1.1 设置GPU

1.2 导入数据

输出

二、数据预处理

2.1 加载数据

2.2 再次检查数据

2.3 配置数据集

2.4 可视化数据

三、构建VGG-16网络

3.1 VGG-16网络介绍

3.2 搭建VGG-16模型

四、编译

五、训练模型

5.1 上次程序的主要Bug

5.2 修改版如下

六、模型评估

七、预测

总结


前言

  • 🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/rnFa-IeY93EpjVu0yzzjkw) 中的学习记录博客
  • 🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)

说在前面

1)本周任务:找到并处理第8周的程序问题;拔高--尝试增加数据增强部分的内容以提高准确率

2)运行环境:Python3.6、Pycharm2020、tensorflow2.4.0


一、前期工作

1.1 设置GPU

代码如下:

# 1.1 设置GPU
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")
# 打印显卡信息,确认GPU可用
print(gpus)

输出:[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

1.2 导入数据

代码如下:

# 1.2 导入数据
import numpy as np
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
import os,PIL,pathlib

#隐藏警告
import warnings
warnings.filterwarnings('ignore')
data_dir = "./data"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)

输出

图片总数为:  3400

二、数据预处理

2.1 加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset,tf.keras.preprocessing.image_dataset_from_directory():是 TensorFlow 的 Keras 模块中的一个函数,用于从目录中创建一个图像数据集(dataset)。这个函数可以以更方便的方式加载图像数据,用于训练和评估神经网络模型

测试集与验证集的关系:

  • 验证集并没有参与训练过程梯度下降过程的,狭义上来讲是没有参与模型的参数训练更新的。
  • 但是广义上来讲,验证集存在的意义确实参与了一个“人工调参”的过程,我们根据每一个epoch训练之后模型在valid data上的表现来决定是否需要训练进行early stop,或者根据这个过程模型的性能变化来调整模型的超参数,如学习率,batch_size等等。因此,我们也可以认为,验证集也参与了训练,但是并没有使得模型去overfit验证集
  • 因此,我们也可以认为,验证集也参与了训练,但是并没有使得模型去overfit验证集

代码如下:

# 二、数据预处理
# 2.1 加载数据
batch_size = 64
img_height = 224
img_width = 224
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
class_names = train_ds.class_names
print(class_names)

输出如下:

Found 3400 files belonging to 2 classes.
Using 2720 files for training.
Found 3400 files belonging to 2 classes.
Using 680 files for validation.

['cat', 'dog']

2.2 再次检查数据

代码如下:

# 2.2 再次检查数据
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

输出:

(64, 224, 224, 3)
(64,)

2.3 配置数据集

代码如下:

# 2.3 配置数据集
AUTOTUNE = tf.data.AUTOTUNE
def preprocess_image(image,label):
    return (image/255.0,label)
# 归一化处理
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

2.4 可视化数据

代码如下:

# 2.4 可视化数据
plt.figure(figsize=(15, 10))  # 图形的宽为15高为10
for images, labels in train_ds.take(1):
    for i in range(8):
        ax = plt.subplot(5, 8, i + 1)
        plt.imshow(images[i])
        plt.title(class_names[labels[i]])
        plt.axis("off")

输出:

三、构建VGG-16网络

3.1 VGG-16网络介绍

结构说明:

  • 13个卷积层(Convolutional Layer),分别用blockX_convX表示
  • 3个全连接层(Fully connected Layer),分别用fcXpredictions表示
  • 5个池化层(Pool layer),分别用blockX_pool表示

VGG优缺点分析:

  • VGG优点:VGG的结构非常简洁,整个网络都使用了同样大小的卷积核尺寸(3x3)和最大池化尺寸(2x2)。
  • VGG缺点:1)训练时间过长,调参难度大。2)需要的存储容量大,不利于部署。例如存储VGG-16权重值文件的大小为500多MB,不利于安装到嵌入式系统中。

网络结构图如下(包含了16个隐藏层--13个卷积层和3个全连接层,故称为VGG-16)

​​

3.2 搭建VGG-16模型

代码如下:

# 三、构建VGG-16网络
from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout

def VGG16(nb_classes, input_shape):
    input_tensor = Input(shape=input_shape)
    # 1st block
    x = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv1')(input_tensor)
    x = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv2')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block1_pool')(x)
    # 2nd block
    x = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv1')(x)
    x = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv2')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block2_pool')(x)
    # 3rd block
    x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv1')(x)
    x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv2')(x)
    x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv3')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block3_pool')(x)
    # 4th block
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv1')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv2')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv3')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block4_pool')(x)
    # 5th block
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv1')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv2')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv3')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block5_pool')(x)
    # full connection
    x = Flatten()(x)
    x = Dense(4096, activation='relu',  name='fc1')(x)
    x = Dense(4096, activation='relu', name='fc2')(x)
    output_tensor = Dense(nb_classes, activation='softmax', name='predictions')(x)
    model = Model(input_tensor, output_tensor)
    return model
model=VGG16(1000, (img_width, img_height, 3))
model.summary()

四、编译

代码如下:

model.compile(optimizer="adam",
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

五、训练模型

5.1 上次程序的主要Bug

训练中的主要问题为acc、loss等的更新计算方式!!!

修改前:将每训练1个batch之后的损失和准确率直接记录进history_train/val_loss和history_train/val_accuracy当中,最后记录的只是整个epoch中最后1个batch所得的损失和准确率而不是整个epoch中训练数据的平均值;

# 记录训练数据,方便后面的分析
history_train_loss = []
history_train_accuracy = []
history_val_loss = []
history_val_accuracy = []
for epoch in range(epochs):
    train_total = len(train_ds)
    val_total = len(val_ds)
    with tqdm(total=train_total, desc=f'Epoch {epoch + 1}/{epochs}', mininterval=1, ncols=100) as pbar:
        lr = lr * 0.92
        K.set_value(model.optimizer.lr, lr)
        for image, label in train_ds:
            history = model.train_on_batch(image, label)
            train_loss = history[0]
            train_accuracy = history[1]
            pbar.set_postfix({"loss": "%.4f" % train_loss,
                              "accuracy": "%.4f" % train_accuracy,
                              "lr": K.get_value(model.optimizer.lr)})
            pbar.update(1)
        history_train_loss.append(train_loss)
        history_train_accuracy.append(train_accuracy)

    print('开始验证!')
    with tqdm(total=val_total, desc=f'Epoch {epoch + 1}/{epochs}', mininterval=0.3, ncols=100) as pbar:
        for image, label in val_ds:
            history = model.test_on_batch(image, label)
            val_loss = history[0]
            val_accuracy = history[1]
            pbar.set_postfix({"loss": "%.4f" % val_loss,
                              "accuracy": "%.4f" % val_accuracy})
            pbar.update(1)
        history_val_loss.append(val_loss)
        history_val_accuracy.append(val_accuracy)
    print('结束验证!')
    print("验证loss为:%.4f" % val_loss)
    print("验证准确率为:%.4f" % val_accuracy)

5.2 修改版如下

修改后: 每次处理一个 batch后,将该 batch 的损失和准确率保存在loss和accuracy列表中。计算1个epoch中所有batch的训练损失和准确率的平均值,并将均值记录到history_train/val_loss或history_train/val_accuracy中。能够更准确地反映整个训练集和验证集上的表现。

代码如下:

# 五、训练模型
from tqdm import tqdm
import tensorflow.keras.backend as K
epochs = 10
lr = 1e-4
# 记录训练数据,方便后面的分析
history_train_loss = []
history_train_accuracy = []
history_val_loss = []
history_val_accuracy = []
for epoch in range(epochs):
    train_total = len(train_ds)
    val_total = len(val_ds)
    """
    total:预期的迭代数目
    ncols:控制进度条宽度
    mininterval:进度更新最小间隔,以秒为单位(默认值:0.1)
    """
    with tqdm(total=train_total, desc=f'Epoch {epoch + 1}/{epochs}', mininterval=1, ncols=100) as pbar:
        lr = lr * 0.92
        K.set_value(model.optimizer.lr, lr)
        train_loss = []
        train_accuracy = []
        for image, label in train_ds:
            # 这里生成的是每一个batch的acc与loss
            history = model.train_on_batch(image, label)
            train_loss.append(history[0])
            train_accuracy.append(history[1])
            pbar.set_postfix({"train_loss": "%.4f" % history[0],
                              "train_acc": "%.4f" % history[1],
                              "lr": K.get_value(model.optimizer.lr)})
            pbar.update(1)
        history_train_loss.append(np.mean(train_loss))
        history_train_accuracy.append(np.mean(train_accuracy))
    print('开始验证!')

    with tqdm(total=val_total, desc=f'Epoch {epoch + 1}/{epochs}', mininterval=0.3, ncols=100) as pbar:
        val_loss = []
        val_accuracy = []
        for image, label in val_ds:
            # 这里生成的是每一个batch的acc与loss
            history = model.test_on_batch(image, label)

            val_loss.append(history[0])
            val_accuracy.append(history[1])

            pbar.set_postfix({"val_loss": "%.4f" % history[0],
                              "val_acc": "%.4f" % history[1]})
            pbar.update(1)
        history_val_loss.append(np.mean(val_loss))
        history_val_accuracy.append(np.mean(val_accuracy))

    print('结束验证!')
    print("验证loss为:%.4f" % np.mean(val_loss))
    print("验证准确率为:%.4f" % np.mean(val_accuracy))

打印训练过程:

六、模型评估

代码如下:

# 六、模型评估
from datetime import datetime
current_time = datetime.now() # 获取当前时间
epochs_range = range(epochs)
plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, history_train_accuracy, label='Training Accuracy')
plt.plot(epochs_range, history_val_accuracy, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time) # 打卡请带上时间戳,否则代码截图无效
plt.subplot(1, 2, 2)
plt.plot(epochs_range, history_train_loss, label='Training Loss')
plt.plot(epochs_range, history_val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

训练结果可视化如下:

​​​

七、预测

代码如下:

# 七、预测
# 采用加载的模型(new_model)来看预测结果
plt.figure(figsize=(18, 3))  # 图形的宽为18高为5
plt.suptitle("预测结果展示")
for images, labels in val_ds.take(1):
    for i in range(8):
        ax = plt.subplot(1, 8, i + 1)
        # 显示图片
        plt.imshow(images[i].numpy())
        # 需要给图片增加一个维度
        img_array = tf.expand_dims(images[i], 0)
        # 使用模型预测图片中的人物
        predictions = model.predict(img_array)
        plt.title(class_names[np.argmax(predictions)])
        plt.axis("off")

输出:

1/1 [==============================] - 0s 129ms/step
1/1 [==============================] - 0s 19ms/step
1/1 [==============================] - 0s 18ms/step
1/1 [==============================] - 0s 18ms/step
1/1 [==============================] - 0s 17ms/step
1/1 [==============================] - 0s 18ms/step
1/1 [==============================] - 0s 17ms/step
1/1 [==============================] - 0s 17ms/step


总结

  • Tensorflow训练过程中打印多余信息的处理,并且引入了进度条的显示方式,更加方便及时查看模型训练过程中的情况,可以及时打印各项指标
  • 发现了上次程序的Bug,对于历次准确率和loss的保存逻辑
  • 下次继续探索采用不同数据增强方式来提高准确率的方法

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2298799.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DeepSeek 概述与本地化部署【详细流程】

目录 一、引言 1.1 背景介绍 1.2 本地化部署的优势 二、deepseek概述 2.1 功能特点 2.2 核心优势 三、本地部署流程 3.1 版本选择 3.2 部署过程 3.2.1 下载Ollama 3.2.2 安装Ollama 3.2.3 选择 r1 模型 3.2.4 选择版本 3.2.5 本地运行deepseek模型 3.3.6 查看…

jenkins war Windows安装

Windows安装Jenkins 需求1.下载jenkins.war2.编写快速运行脚本3.启动Jenkins4.Jenkins使用 需求 1.支持在Windows下便捷运行Jenkins; 2.支持自定义启动参数; 3.有快速运行的脚步样板。 1.下载jenkins.war Jenkins下载地址:https://get.j…

3D打印技术:如何让古老文物重获新生?

如何让古老文物在现代社会中焕发新生是一个重要议题。传统文物保护方法虽然在一定程度上能够延缓文物的损坏,但在文物修复、展示和传播方面仍存在诸多局限。科技发展进步,3D打印技术为古老文物的保护和传承提供了全新的解决方案。我们来探讨3D打印技术如…

Vue h函数到底是个啥?

h 到底是个啥? 对于了解或学习Vue高阶组件(HOC)的同学来说,h() 函数无疑是一个经常遇到的概念。 那么,这个h() 函数究竟如何使用呢,又在什么场景下适合使用呢? 一、h 是什么 看到这个函数你可…

深入浅出 Python Logging:从基础到进阶日志管理

在 Python 开发过程中,日志(Logging)是不可或缺的调试和监控工具。合理的日志管理不仅能帮助开发者快速定位问题,还能提供丰富的数据支持,让应用更具可观测性。本文将带你全面了解 Python logging 模块,涵盖…

Android WindowContainer窗口结构

Android窗口是根据显示屏幕来管理,每个显示屏幕的窗口层级分为37层,0-36层。每层可以放置多个窗口,上层窗口覆盖下面的。 要理解窗口的结构,需要学习下WindowContainer、RootWindowContainer、DisplayContent、TaskDisplayArea、T…

2025年最新版1688平台图片搜索接口技术指南及Python实现

随着电商行业的蓬勃发展,1688作为国内领先的B2B交易平台,其商品搜索功能对于买家和卖家而言都至关重要。图片搜索作为其中的一种高级搜索方式,能够极大地提升用户的搜索体验和准确性。本文将详细介绍如何通过API接口实现1688平台的图片搜索功…

基于A*算法与贝塞尔曲线的路径规划与可视化:从栅格地图到平滑路径生成

引言 在机器人导航、自动驾驶和游戏开发等领域,路径规划是一个核心问题。如何高效地找到从起点到终点的最优路径,并且确保路径的平滑性和安全性,是许多应用场景中的关键挑战。本文将介绍一种结合A算法和贝塞尔曲线的路径规划方法,并通过Pygame实现可视化。我们将从栅格地图…

使用verilog 实现 cordic 算法 ----- 旋转模式

1-设计流程 ● 了解cordic 算法原理,公式,模式,伸缩因子,旋转方向等,推荐以下链接视频了解 cordic 算法。哔哩哔哩-cordic算法原理讲解 ● 用matlab 或者 c 实现一遍算法 ● 在FPGA中用 verilog 实现,注意…

【css】width:100%;padding:20px;造成超出100%宽度的解决办法 - box-sizing的使用方法 - CSS布局

问题 修改效果 解决方法 .xx {width: 100%;padding: 0 20px;box-sizing: border-box; } 默认box-sizing: content-box下, width 内容的宽度 height 内容的高度 宽度和高度的计算值都不包含内容的边框(border)和内边距(padding&…

贪心算法_翻硬币

蓝桥账户中心 依次遍历 不符合条件就反转 题目要干嘛 你就干嘛 #include <bits/stdc.h>#define endl \n using namespace std;int main() {ios::sync_with_stdio(0); cin.tie(0); cout.tie(0); string s; cin >> s;string t; cin >> t;int ret 0;for ( i…

深入HBase——引入

引入 前面我们通过深入HDFS到深入MapReduce &#xff0c;从设计和落地&#xff0c;去深入了解了大数据最底层的基石——存储与计算是如何实现的。 这个专栏则开始来看大数据的三驾马车中最后一个。 通过前面我们对于GFS和MapReduce论文实现的了解&#xff0c;我们知道GFS在数…

2025年02月12日Github流行趋势

项目名称&#xff1a;data-formulator 项目地址url&#xff1a;https://github.com/microsoft/data-formulator 项目语言&#xff1a;TypeScript 历史star数&#xff1a;4427 今日star数&#xff1a;729 项目维护者&#xff1a;danmarshall, Chenglong-MS, apps/dependabot, mi…

【落羽的落羽 数据结构篇】双向链表

文章目录 一、链表的分类二、双向链表1. 结构2. 申请一个新节点3. 尾部插入数据4. 头部插入数据5. 尾部删除数据6. 头部删除数据7. 在指定位置之后插入数据8. 删除指定位置节点9. 销毁链表 一、链表的分类 链表的分类实际上要从这三个方向分析&#xff1a;是否带头、单向还是双…

Golang的并发编程问题解决思路

Golang的并发编程问题解决思路 一、并发编程基础 并发与并行 在计算机领域&#xff0c;“并发”和“并行”经常被混为一谈&#xff0c;但它们有着不同的含义。并发是指一段时间内执行多个任务&#xff0c;而并行是指同时执行多个任务。在 Golang 中&#xff0c;通过 goroutines…

剑指offer第2版:搜索算法(二分/DFS/BFS)

查找本质就是排除的过程&#xff0c;不外乎顺序查找、二分查找、哈希查找、二叉排序树查找、DFS/BFS查找 一、p39-JZ3 找出数组中重复的数字&#xff08;利用特性&#xff09; 数组中重复的数字_牛客题霸_牛客网 方法1&#xff1a;全部排序再进行逐个扫描找重复。 时间复杂…

在 CentOS 上更改 SSH 默认端口以提升服务器安全性

&#x1f680; 作者主页&#xff1a; 有来技术 &#x1f525; 开源项目&#xff1a; youlai-mall ︱vue3-element-admin︱youlai-boot︱vue-uniapp-template &#x1f33a; 仓库主页&#xff1a; GitCode︱ Gitee ︱ Github &#x1f496; 欢迎点赞 &#x1f44d; 收藏 ⭐评论 …

2025年:边缘计算崛起下运维应对新架构挑战

一、引言 随着科技的飞速发展&#xff0c;2025年边缘计算正以前所未有的速度崛起&#xff0c;给运维行业带来了全新的架构挑战。在这个充满机遇与挑战的时代&#xff0c;美信时代公司的美信监控易运维管理软件成为运维领域应对这些挑战的有力武器。 二、边缘计算崛起带来的运维…

怎么理解 Spring Boot 的约定优于配置 ?

在传统的 Spring 开发中&#xff0c;大家可能都有过这样的经历&#xff1a;项目还没开始写几行核心业务代码&#xff0c;就已经在各种配置文件中耗费了大量时间。比如&#xff0c;要配置数据库连接&#xff0c;不仅要在 XML 文件里编写冗长的数据源配置&#xff0c;还要处理事务…

学习总结2.14

深搜将题目分配&#xff0c;如果是两个题目&#xff0c;就可以出现左左&#xff0c;左右&#xff0c;右左&#xff0c;右右四种时间分配&#xff0c;再在其中找最小值&#xff0c;即是两脑共同处理的最小值 #include <stdio.h> int s[4]; int sum0; int brain[25][25]; …