Tensorflow2.0笔记 - 使用卷积神经网络层做CIFA100数据集训练(类VGG13)

news2025/1/17 23:16:27

        本笔记记录CNN做CIFAR100数据集的训练相关内容,代码中使用了类似VGG13的网络结构,做了两个Sequetial(CNN和全连接层),没有用Flatten层而是用reshape操作做CNN和全连接层的中转操作。由于网络层次较深,参数量相比之前的网络多了不少,因此只做了10次epoch(RTX4090),没有继续跑了,最终准确率大概在33.8%左右。

import os
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Input

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
tf.__version__

#如果下载很慢,可以使用迅雷下载到本地,迅雷的链接也可以直接用官网URL:
#      https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
#下载好后,将cifar-100.python.tar.gz放到 .keras\datasets 目录下(我的环境是C:\Users\Administrator\.keras\datasets)
# 参考:https://blog.csdn.net/zy_like_study/article/details/104219259
(x_train,y_train), (x_test, y_test) = datasets.cifar100.load_data()
print("Train data shape:", x_train.shape)
print("Train label shape:", y_train.shape)
print("Test data shape:", x_test.shape)
print("Test label shape:", y_test.shape)

def preprocess(x, y):
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    return x,y

y_train = tf.squeeze(y_train, axis=1)
y_test = tf.squeeze(y_test, axis=1)

batch_size = 128
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.shuffle(1000).map(preprocess).batch(batch_size)

test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).batch(batch_size)

sample = next(iter(train_db))
print("Train data sample:", sample[0].shape, sample[1].shape, 
         tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))


#创建CNN网络,总共4个unit,每个unit主要是两个卷积层和Max Pooling池化层
cnn_layers = [
    #unit 1
    layers.Conv2D(64, kernel_size=[3,3], padding='same', activation='relu'),
    layers.Conv2D(64, kernel_size=[3,3], padding='same', activation='relu'),
    #layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),
    layers.MaxPool2D(pool_size=[2,2], strides=2),

    #unit 2
    layers.Conv2D(128, kernel_size=[3,3], padding='same', activation='relu'),
    layers.Conv2D(128, kernel_size=[3,3], padding='same', activation='relu'),
    #layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),
    layers.MaxPool2D(pool_size=[2,2], strides=2),

    #unit 3
    layers.Conv2D(256, kernel_size=[3,3], padding='same', activation='relu'),
    layers.Conv2D(256, kernel_size=[3,3], padding='same', activation='relu'),
    #layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),
    layers.MaxPool2D(pool_size=[2,2], strides=2),

    #unit 4
    layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),
    layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),
    #layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),
    layers.MaxPool2D(pool_size=[2,2], strides=2),

    #unit 5
    layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),
    layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),
    #layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),
    layers.MaxPool2D(pool_size=[2,2], strides=2),
]


def main():
    #[b, 32, 32, 3] => [b, 1, 1, 512]
    cnn_net = Sequential(cnn_layers)
    cnn_net.build(input_shape=[None, 32, 32, 3])
    
    #测试一下卷积层的输出
    #x = tf.random.normal([4, 32, 32, 3])
    #out = cnn_net(x)
    #print(out.shape)

    #创建全连接层, 输出为100分类
    fc_net = Sequential([
        layers.Dense(256, activation='relu'),
        layers.Dense(128, activation='relu'),
        layers.Dense(100, activation=None),
    ])
    fc_net.build(input_shape=[None, 512])

    #设置优化器
    optimizer = optimizers.Adam(learning_rate=1e-4)

    #记录cnn层和全连接层所有可训练参数, 实现的效果类似list拼接,比如
    # [1, 2] + [3, 4] => [1, 2, 3, 4]
    variables = cnn_net.trainable_variables + fc_net.trainable_variables
    #进行训练
    num_epoches = 10
    for epoch in range(num_epoches):
        for step, (x,y) in enumerate(train_db):
            with tf.GradientTape() as tape:
                #[b, 32, 32, 3] => [b, 1, 1, 512]
                out = cnn_net(x)
                #flatten打平 => [b, 512]
                out = tf.reshape(out, [-1, 512])
                #使用全连接层做100分类logits输出
                #[b, 512] => [b, 100]
                logits = fc_net(out)
                #标签做one_hot encoding
                y_onehot = tf.one_hot(y, depth=100)
                #计算损失
                loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
                loss = tf.reduce_mean(loss)
            #计算梯度
            grads = tape.gradient(loss, variables)
            #更新参数
            optimizer.apply_gradients(zip(grads, variables))

            if (step % 100 == 0):
                print("Epoch[", epoch + 1, "/", num_epoches, "]: step-", step, " loss:", float(loss))
        #进行验证
        total_samples = 0
        total_correct = 0
        for x,y in test_db:
            out = cnn_net(x)
            out = tf.reshape(out, [-1, 512])
            logits = fc_net(out)
            prob = tf.nn.softmax(logits, axis=1)
            pred = tf.argmax(prob, axis=1)
            pred = tf.cast(pred, dtype=tf.int32)
            correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)
            correct = tf.reduce_sum(correct)

            total_samples += x.shape[0]
            total_correct += int(correct)

        #统计准确率
        acc = total_correct / total_samples
        print("Epoch[", epoch + 1, "/", num_epoches, "]: accuracy:", acc)
if __name__ == '__main__':
    main()

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1608480.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

鲁棒无监督人群计数与定位

鲁棒无监督人群计数与定位 摘要1 IntroductionMethod 摘要 现有的群体计数模型需要大量的训练数据,而这些数据的标注过程耗时且繁琐。为了解决这个问题,作者提出了一种简单而有效的人群计数方法,通过采用一种名为“Segment-Everything-Every…

高精度算法(1)

前言 今天来讲一讲高精度算法,我们说一个数据类型,有它的对应范围比如int类型最多 可以包含到负2的31次方到2的31次方减一 其实大概就是20亿左右那么其他的类型也同样如此 那么,如何解决一个很大很大的数的运算呢? 我们今天介…

gemini国内能用吗

gemini国内能用吗 虽然 Gemini 的具体功能和性能还未完全公开,但基于 Google 在 AI 领域的强大背景和技术实力,已经火出圈了,很多小伙伴已经迫不及待想了解一下它有什么优势以及如何快速使用上 首先我们来讲一下gemini的优势 多模态能力&a…

43、二叉树-验证二叉搜索树

思路: 有效 二叉搜索树定义如下: 节点的左子树只包含 小于 当前节点的数。节点的右子树只包含 大于 当前节点的数。所有左子树和右子树自身必须也是二叉搜索树。 所以对于当前节点来说:我的左节点要小于我,我的右节点要大于我&a…

顺序表详解(C语言实现)

顺序表介绍 顺序表是用一段物理地址连续的存储单元依次存储数据元素的线性结构,一般情况下采用数组存 储。在数组上完成数据的增删查改。 顺序表一般可以分为: 1. 静态顺序表:使用定长数组存储元素。 2. 动态顺序表:使…

世界500强:破解“智慧核能”数智化成功转型密码

近日,实在智能携手中国核能行业协会信息化专业委员会在中国人工智能小镇成功举办“基于大模型的RPA数字员工在核能行业实战应用案例专项培训”,中国核工业集团、中国广核集团、国家电力投资集团等企事业单位共同参加。中核集团作为我国核科技工业的主体&…

screen常用命令

screen是一个在Linux系统中常用的命令行终端模拟器&#xff0c;它允许用户在一个单一终端会话中管理多个终端窗口。以下是一些常用的screen命令 1、创建一个新的screen会话并命名 screen -S <name>2、control a d &#xff1a;分离&#xff08;detach&#xff09;当前的…

TensorRT从入门到了解(2)-学习笔记

目录 1.TensorRT的高性能部署简介2.TensorRT驾驭方案3.如何正确导出onnx4.动态batch和动态宽高的实现5.实现一个自定义插件6.关于封装7.YoloV5案例8.Retinaface案例9.高性能低耦合10.YOLOX集成参考 1.TensorRT的高性能部署简介 tensorRT&#xff0c;nvidia发布的dnn推理引擎&a…

Kotlin语法快速入门--变量声明(1)

Kotlin语法入门–变量声明&#xff08;1&#xff09; 文章目录 Kotlin语法入门--变量声明&#xff08;1&#xff09;一、变量声明1、整型2、字符型3、集合3.1、创建array数组3.2、创建list集合3.3、不可变类型数组3.4、Set集合--不重复添加元素3.5、键值对集合Map 4、kotlin特有…

yolov8 区域计数

yolov8 区域计数 1. 基础2. 计数功能2.1 计数模块2.2 判断模块 3. 主代码4. 实验结果5. 源码 1. 基础 本项目是在 WindowsYOLOV8环境配置 的基础上实现的&#xff0c;测距原理可见上边文章 2. 计数功能 2.1 计数模块 在指定区域内计数模块 def count_objects_in_region(bo…

浅谈rDNS在IP情报建设中的应用

在当今数字化世界中&#xff0c;互联网已经成为人们日常生活和商业活动中不可或缺的一部分。在这个庞大而复杂的网络生态系统中&#xff0c;IP地址是连接和识别各种网络设备和服务的基础。然而&#xff0c;仅仅知道一个设备的IP地址并不足以充分理解其在网络中的角色和行为。为…

第四百六十七回

文章目录 1. 知识回顾2. 使用方法3. 示例代码4. 内容总结 我们在上一章回中介绍了"OverlayEntry组件简介"相关的内容&#xff0c;本章回中将介绍OverlayEntry组件的用法.闲话休提&#xff0c;让我们一起Talk Flutter吧。 1. 知识回顾 我们在上一章回中介绍了Overlay…

【简单介绍下K-means聚类算法】

&#x1f308;个人主页: 程序员不想敲代码啊 &#x1f3c6;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f44d;点赞⭐评论⭐收藏 &#x1f91d;希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正&#xff0c;让我们共…

探索AI大模型:理论、技术与应用

引言 近年来&#xff0c;随着深度学习技术的迅猛发展&#xff0c;AI大模型已经成为人工智能领域的重要研究方向和热点话题。AI大模型&#xff0c;指的是拥有巨大参数规模和强大学习能力的神经网络模型&#xff0c;如BERT、GPT等&#xff0c;这些模型在自然语言处理、计算机视觉…

腾讯一面:你了解js的沙箱环境吗?

去年的面试了&#xff0c;最近复盘了一下&#xff0c;发现菜的一批&#xff0c;有些问题一下子就答出来了&#xff0c;现在答的话&#xff0c;那时候还在瞎鸡儿答我也不知道答的什么。。。。 在 JavaScript 中&#xff0c;沙箱&#xff08;sandbox&#xff09;是一个安全机制&…

目标检测——大规模鱼类数据集

一、重要性及意义 生物多样性研究&#xff1a;鱼类是水生生态系统中的重要组成部分&#xff0c;其种类多样性对于维持生态平衡至关重要。通过对鱼类进行准确的分割和分类&#xff0c;可以更好地了解不同鱼类的生态习性、分布情况以及与其他生物的相互作用&#xff0c;进而为保护…

单片机基础知识 07

一. 键盘检测 键盘分为编码键盘和非编码键盘。 编码键盘 &#xff1a;键盘上闭合键的识别由专用的硬件编码器实现&#xff0c;并产生键编码号或者键值&#xff0c;如计算机键盘。 非编码键盘&#xff1a;靠软件编程来识别。 在单片机组成的各种系统中&#xff0c;用的较多的…

Echarts-知识图谱

Echarts-知识图谱 demo地址 打开CodePen 效果 思路 1. 生成根节点 2. 根据子节点距离与根节点的角度关系&#xff0c;生成子节点坐标&#xff0c;进而生成子节点 3. 从子节点上按角度生成对应的子节点 4. 递归将根节点与每一层级子节点连线核心代码 定义节点配置 functio…

将 Notepad++ 添加到右键菜单

目录 方式一&#xff1a;添加注册表&#xff08;手动&#xff09; 方式二&#xff1a;添加注册表&#xff08;一键添加&#xff09; 有时安装了notepad后&#xff0c;在txt文件上右键&#xff0c;在弹出的菜单栏中没有【通过 Notepad 打开】&#xff0c;如下&#xff1a; 这…

5. Django 探究CBV视图

5. 探究CBV视图 Web开发是一项无聊而且单调的工作, 特别是在视图功能编写方面更为显著. 为了减少这种痛苦, Django植入了视图类这一功能, 该功能封装了视图开发常用的代码, 无须编写大量代码即可快速完成数据视图的开发, 这种以类的形式实现响应与请求处理称为CBV(Class Base…