CNN|ResNet-50

news2025/2/23 2:36:09

 导入数据

import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

import os,PIL,pathlib
import numpy as np

from tensorflow import keras
from tensorflow.keras import layers,models
data_dir = "/Users/yueyishen/jupter/data/bird_photos"

data_dir = pathlib.Path(data_dir)

查看数据

image_count = len(list(data_dir.glob('*/*')))

print("图片总数为:",image_count)
图片总数为: 565

数据预处理

1. 加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset中

batch_size = 8
img_height = 224
img_width = 224
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

Found 565 files belonging to 4 classes. Using 452 files for training. 

"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

Found 565 files belonging to 4 classes. Using 113 files for validation.

class_names = train_ds.class_names
print(class_names)

['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']

可视化数据

plt.figure(figsize=(10, 5))  # 图形的宽为10高为5
plt.suptitle("瓜牛")

for images, labels in train_ds.take(1):
    for i in range(8):
        
        ax = plt.subplot(2, 4, i + 1)  

        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        
        plt.axis("off")

再次检查数据

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

(8, 224, 224, 3) (8,)

配置数据集

● shuffle() : 打乱数据

● prefetch() :预取数据,加速运行,其详细介绍可以参考我前两篇文章,里面都有讲解。

● cache() :将数据集缓存到内存当中,加速运行

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

构建ResNet-50网络模型

from keras import layers

from keras.layers import Input,Activation,BatchNormalization,Flatten
from keras.layers import Dense,Conv2D,MaxPooling2D,ZeroPadding2D,AveragePooling2D
from keras.models import Model

def identity_block(input_tensor, kernel_size, filters, stage, block):

    filters1, filters2, filters3 = filters

    name_base = str(stage) + block + '_identity_block_'

    x = Conv2D(filters1, (1, 1), name=name_base + 'conv1')(input_tensor)
    x = BatchNormalization(name=name_base + 'bn1')(x)
    x = Activation('relu', name=name_base + 'relu1')(x)

    x = Conv2D(filters2, kernel_size,padding='same', name=name_base + 'conv2')(x)
    x = BatchNormalization(name=name_base + 'bn2')(x)
    x = Activation('relu', name=name_base + 'relu2')(x)

    x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)
    x = BatchNormalization(name=name_base + 'bn3')(x)

    x = layers.add([x, input_tensor] ,name=name_base + 'add')
    x = Activation('relu', name=name_base + 'relu4')(x)
    return x
    # 在残差网络中,广泛地使用了BN层;但是没有使用MaxPooling以便减小特征图尺寸,
# 作为替代,在每个模块的第一层,都使用了strides = (2, 2)的方式进行特征图尺寸缩减,
# 与使用MaxPooling相比,毫无疑问是减少了卷积的次数,输入图像分辨率较大时比较适合
# 在残差网络的最后一级,先利用layer.add()实现H(x) = x + F(x)
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):

    filters1, filters2, filters3 = filters

    res_name_base = str(stage) + block + '_conv_block_res_'
    name_base = str(stage) + block + '_conv_block_'

    x = Conv2D(filters1, (1, 1), strides=strides, name=name_base + 'conv1')(input_tensor)
    x = BatchNormalization(name=name_base + 'bn1')(x)
    x = Activation('relu', name=name_base + 'relu1')(x)

    x = Conv2D(filters2, kernel_size, padding='same', name=name_base + 'conv2')(x)
    x = BatchNormalization(name=name_base + 'bn2')(x)
    x = Activation('relu', name=name_base + 'relu2')(x)

    x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)
    x = BatchNormalization(name=name_base + 'bn3')(x)

    shortcut = Conv2D(filters3, (1, 1), strides=strides, name=res_name_base + 'conv')(input_tensor)
    shortcut = BatchNormalization(name=res_name_base + 'bn')(shortcut)

    x = layers.add([x, shortcut], name=name_base+'add')
    x = Activation('relu', name=name_base+'relu4')(x)
    return x

def ResNet50(input_shape=[224,224,3],classes=1000):

    img_input = Input(shape=input_shape)
    x = ZeroPadding2D((3, 3))(img_input)

    x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x)
    x = BatchNormalization(name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    x =     conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

    x =     conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')

    x =     conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
    x =     conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')

    x = AveragePooling2D((7, 7), name='avg_pool')(x)

    x = Flatten()(x)
    x = Dense(classes, activation='softmax', name='fc1000')(x)

    model = Model(img_input, x, name='resnet50')
    
    # 加载预训练模型
    model.load_weights("/Users/yueyishen/jupter/data/resnet50_weights_tf_dim_ordering_tf_kernels.h5")

    return model

model = ResNet50()
model.summary()

编译

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

● 损失函数(loss):用于衡量模型在训练期间的准确率。

● 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。

● 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率

model.compile(optimizer="adam",
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

训练模型

epochs = 10

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)

 模型评估

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.suptitle("微信公众号:K同学啊")

plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

 预测

# 采用加载的模型(new_model)来看预测结果

plt.figure(figsize=(10, 5))  # 图形的宽为10高为5
plt.suptitle("微信公众号:K同学啊")

for images, labels in val_ds.take(1):
    for i in range(8):
        ax = plt.subplot(2, 4, i + 1)  
        
        # 显示图片
        plt.imshow(images[i].numpy().astype("uint8"))
        
        # 需要给图片增加一个维度
        img_array = tf.expand_dims(images[i], 0) 
        
        # 使用模型预测图片中的人物
        predictions = model.predict(img_array)
        plt.title(class_names[np.argmax(predictions)])

        plt.axis("off")

 

总结:

  • 数据导入与预处理:首先导入必要的库,设置数据目录,查看数据总数为 565 张图片。使用 image_dataset_from_directory 方法将磁盘中的数据加载为训练集和验证集,进行数据预处理,包括设置图像大小、批次大小等,并对数据集进行打乱、预取和缓存等操作。
  • 可视化数据:通过 plt.figure 和循环展示了训练集中的部分图片,并标注了图片的类别名称。再次检查数据,打印出图像批次的形状和标签批次的形状。
  • 构建 ResNet-50 网络模型:定义了 identity_block 和 conv_block 函数,用于构建 ResNet-50 模型。该模型接收输入形状为 [224,224,3] 的图像,经过一系列卷积、批归一化、激活和残差连接等操作,最后输出分类结果。
  • 编译模型:在编译模型时,设置损失函数为 sparse_categorical_crossentropy,优化器为 adam,指标为准确率。
  • 训练模型:使用训练集和验证集对模型进行训练,设置训练轮数为 10 轮。
  • 模型评估:绘制训练和验证的准确率及损失曲线,以评估模型的性能。
  • 预测:对验证集中的部分图片进行预测,展示预测结果的类别名称

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2297967.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

吉祥汽车泰国首发,用 Unity 实现行业首创全 3D 座舱虚拟世界

11 月 19 日,均瑶集团吉祥智驱(以下简称“吉祥汽车”)首款纯电动汽车 JY AIR 在泰国首发。延续吉祥航空在飞行体验上的优势,吉祥汽车对 JY AIR 赋予了将航空级服务标准延伸至地面的使命,为用户提供一站式大出行体验。此…

【OpenCV】双目相机计算深度图和点云

双目相机计算深度图的基本原理是通过两台相机从不同角度拍摄同一场景,然后利用视差来计算物体的距离。本文的Python实现示例,使用OpenCV库来处理图像和计算深度图。 1、数据集介绍 Mobile stereo datasets由Pan Guanghan、Sun Tiansheng、Toby Weed和D…

Uniapp 原生组件层级过高问题及解决方案

文章目录 一、引言🏅二、问题描述📌三、问题原因❓四、解决方案💯4.1 使用 cover-view 和 cover-image4.2 使用 subNVue 子窗体4.3 动态隐藏原生组件4.4 使用 v-if 或 v-show 控制组件显示4.5 使用 position: fixed 布局 五、总结&#x1f38…

【数据结构初阶第十节】队列(详解+附源码)

好久不见。。。别不开心了,听听喜欢的歌吧 必须有为成功付出代价的决心,然后想办法付出这个代价。云边有个稻草人-CSDN博客 目录 一、概念和结构 二、队列的实现 Queue.h Queue.c test.c Relaxing Time! ————————————《有没…

250213-RHEL8.8-外接SSD固态硬盘

It seems that the exfat-utils package is still unavailable, even after enabling the RPM Fusion repository. This could happen if the repository metadata hasn’t been updated or if the package isn’t directly available in the RPM Fusion repository for RHEL 8…

游戏引擎学习第99天

仓库:https://gitee.com/mrxiao_com/2d_game_2 黑板:制作一些光场(Light Field) 当前的目标是为游戏添加光照系统,并已完成了法线映射(normal maps)的管道,但还没有创建可以供这些正常映射采样的光场。为了继续推进&…

Linux初始化 配置yum源

问题出现:(报错) 1 切换路径 2 备份需要操作的文件夹 3 更改 CentOS 的 YUM 仓库配置文件,以便使用阿里云的镜像源。 4 清除旧的yum缓存 5 关闭防火墙 6 生成新的yum缓存 7 更新系统软件包 8 安装软件包

【笛卡尔树】

笛卡尔树 笛卡尔树定义构建性质 习题P6453 [COCI 2008/2009 #4] PERIODNICF1913D Array CollapseP4755 Beautiful Pair[ARC186B] Typical Permutation Descriptor 笛卡尔树 定义 笛卡尔树是一种二叉树,每一个节点由一个键值二元组 ( k , w ) (k,w) (k,w) 构成。要…

Java String 类深度解析:内存模型、常量池与核心机制

目录 一、String初识 1. 字符串字面量的处理流程 (1) 编译阶段 (2) 类加载阶段 (3) 运行时阶段 2. 示例验证 示例 1:字面量直接赋值 示例 2:使用 new 创建字符串 示例 3:显式调用 intern() 注意点1: ⑴. String s1 &q…

探索顶级汽车软件解决方案:驱动行业变革的关键力量

在本文中,将一同探索当今塑造汽车行业的最具影响力的软件解决方案。从设计到制造,软件正彻底改变车辆的制造与维护方式。让我们深入了解这个充满活力领域中的关键技术。 设计软件:创新车型的孕育摇篮 车辆设计软件对于创造创新型汽车模型至…

TikTok走红全球:中国短视频平台以全新姿态登陆海外市场

在数字化浪潮中,短视频已经成为全球年轻人表达自我、分享生活的重要方式。TikTok,这个起源于中国的短视频平台,以其独特的魅力和创新的功能在全球范围内迅速走红。本文将探讨TikTok如何以全新姿态登陆海外市场,并分析其成功的关键…

计算机毕业设计Python旅游评论情感分析 NLP情感分析 LDA主题分析 bayes分类 旅游爬虫 旅游景点评论爬虫 机器学习 深度学习

温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 温馨提示:文末有 CSDN 平台官方提供的学长联系方式的名片! 作者简介:Java领…

day9手机创意软件

趣味类 in:记录趣味生活(通用) 魔漫相机:真人变漫画(通用) 活照片:让照片活过来(通用) 画中画相机:与众不同的艺术 年龄检测仪:比一比谁更年轻&#xf…

A001基于SpringBoot实现的小区物业管理系统

系统介绍 基于SpringBoot实现的小区物业管理系统是为物业管理打造的一款在线管理平台,它可以实时完成信息处理,对小区信息、住户等进行在线管理,使其系统化和规范化。 系统功能说明 1、系统共有物业、业主角色,物业拥有系统最高…

【Uniapp】关于实现下拉刷新的三种方式

在小程序、h5等地方中,常常会用到下拉刷新这个功能,今天来讲解实现这个功能的三种方式:全局下拉刷新,组件局部下拉刷新,嵌套组件下拉刷新。 全局下拉刷新 这个方式简单,性能佳,最推荐&#xf…

深入理解DeepSeek与企业实践(二):32B多卡推理的原理、硬件散热与性能实测

前言 在《深入理解 DeepSeek 与企业实践(一):蒸馏、部署与评测》文章中,我们详细介绍了深度模型的蒸馏、量化技术,以及 7B 模型的部署基础,通常单张 GPU 显存即可满足7B模型完整参数的运行需求。然而&…

数据结构-链式二叉树

文章目录 一、链式二叉树1.1 链式二叉树的创建1.2 根、左子树、右子树1.3 二叉树的前中后序遍历1.3.1前(先)序遍历1.3.2中序遍历1.3.3后序遍历 1.4 二叉树的节点个数1.5 二叉树的叶子结点个数1.6 第K层节点个数1.7 二叉树的高度1.8 查找指定的值(val)1.9 二叉树的销毁 二、层序…

【云安全】云原生- K8S etcd 未授权访问

什么是etcd? etcd 是一个开源的分布式键值存储系统,主要用于存储和管理配置信息、状态数据以及服务发现信息。它采用 Raft 共识算法,确保数据的一致性和高可用性,能够在多个节点上运行,保证在部分节点故障时仍能继续提…

AI时代的前端开发:对抗压力的利器

在飞速发展的AI时代,前端开发工程师们面临着前所未有的挑战。项目周期不断缩短,需求变化日新月异,交付压力更是与日俱增,这使得开发人员承受着巨大的压力。如何提升对抗压能力,成为摆在每一位前端工程师面前的重要课题…

在npm上传属于自己的包

前言 最近在整理代码,上传到npm方便使用,所以学习了如何在npm发布一个包,整理写成一篇文章和大家一起交流。 修改记录 更新内容更新时间文章第一版25.2.10新增“使用包”,“删除包的测试”25.2.12 1、注册npm账号 npm | Home 2…