Tensorflow2.0笔记 - ResNet实践

news2024/10/6 6:44:40

        本笔记记录使用ResNet18网络结构,进行CIFAR100数据集的训练和验证。由于参数较多,训练时间会比较长,因此只跑了10个epoch,准确率还没有提升上去。

import os
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Input

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
tf.__version__


#关于ResNet的描述,可以参考如下链接:
#https://blog.csdn.net/qq_39770163/article/details/126169080
#代码基于ResNet18结构,有少许不一样
class BasicBlock(layers.Layer):
    def __init__(self, filter_num, strides = 1):
        super(BasicBlock, self).__init__()
        #卷积层1
        self.conv1 = layers.Conv2D(filter_num, (3,3), strides = strides, padding='same')
        #BN层
        self.bn1 = layers.BatchNormalization()
        #Relu层
        self.relu = layers.Activation('relu')

        #卷积层2,BN层2,
        self.conv2 = layers.Conv2D(filter_num, (3,3), strides = 1, padding='same')
        self.bn2 = layers.BatchNormalization()

        #Shortcut
        if strides != 1:
            #如果strides不为1,需要下采样
            self.downsample = Sequential()
            self.downsample.add(layers.Conv2D(filter_num, (1,1), strides=strides))
        else:
            #strides为1, 直接返回原始值即可
            self.downsample = lambda x:x
        
    def call(self, inputs, training = None):
        #经过第一个卷积层,BN和Relu
        out = self.conv1(inputs)
        out = self.bn1(out)
        out = self.relu(out)

        #经过第二个卷积层
        out = self.conv2(out)
        out = self.bn2(out)

        #Shortt处理,out和输入相加
        identity = self.downsample(inputs)
        output = layers.add([out, identity])
        #再经过一个relu
        output = tf.nn.relu(output)
        return output

class ResNet(keras.Model):
    #layer_dims表示对应位置的ResBlock包含了几个BasicBlock
    #比如[2,2,2,2] => 总共4个ResBlock,每个ResBlock包含两个BasicBlock
    #num_classes表示输出的类别的个数
    def __init__(self, layer_dims, num_classes=100):
        super(ResNet, self).__init__()
        #预处理单元
        self.stem = Sequential([layers.Conv2D(64, (3,3), strides=(1,1)),
                                layers.BatchNormalization(),
                                layers.Activation('relu'),
                                layers.MaxPool2D(pool_size=(2,2), strides=(1,1), padding='same')
                               ])
        #创建中间ResBlock层
        self.layer1 = self.buildResBlock(64, layer_dims[0])
        self.layer2 = self.buildResBlock(128, layer_dims[1], strides=2)
        self.layer3 = self.buildResBlock(256, layer_dims[2], strides=2)
        self.layer4 = self.buildResBlock(512, layer_dims[3], strides=2)

        #自适应输出层
        self.avgpool = layers.GlobalAveragePooling2D()
        #全连接层
        self.fc = layers.Dense(num_classes)

    def call(self, inputs, training = None):
        x = self.stem(inputs)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        #经过avgpool => [b, 512]
        x = self.avgpool(x)
        #经过Dense => [b, 100]
        x = self.fc(x)
        return x

    def buildResBlock(self, filter_num, blocks, strides = 1):
        resBlocks = Sequential()
        resBlocks.add(BasicBlock(filter_num, strides))
        #后续的resBlock的strides都设置为1
        for _ in range(1, blocks):
            resBlocks.add(BasicBlock(filter_num))
        return resBlocks;

def ResNet18():
    return ResNet([2, 2, 2 ,2]);

def ResNet34():
    return ResNet([3, 4, 6, 3])


#加载CIFAR100数据集
#如果下载很慢,可以使用迅雷下载到本地,迅雷的链接也可以直接用官网URL:
#      https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
#下载好后,将cifar-100.python.tar.gz放到 .keras\datasets 目录下(我的环境是C:\Users\Administrator\.keras\datasets)
# 参考:https://blog.csdn.net/zy_like_study/article/details/104219259
(x_train,y_train), (x_test, y_test) = datasets.cifar100.load_data()
print("Train data shape:", x_train.shape)
print("Train label shape:", y_train.shape)
print("Test data shape:", x_test.shape)
print("Test label shape:", y_test.shape)

def preprocess(x, y):
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    return x,y

y_train = tf.squeeze(y_train, axis=1)
y_test = tf.squeeze(y_test, axis=1)

batch_size = 128
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.shuffle(1000).map(preprocess).batch(batch_size)

test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).batch(batch_size)

sample = next(iter(train_db))
print("Train data sample:", sample[0].shape, sample[1].shape, 
         tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))



def main():
    #创建ResNet
    resNet = ResNet18()
    resNet.build(input_shape=[None, 32, 32, 3])
    resNet.summary()
    
    #设置优化器
    optimizer = optimizers.Adam(learning_rate=1e-3)
    #进行训练
    num_epoches = 10
    for epoch in range(num_epoches):
        for step, (x,y) in enumerate(train_db):
            with tf.GradientTape() as tape:
                #[b, 32, 32, 3] => [b, 100]
                logits = resNet(x)
                #标签做one_hot encoding
                y_onehot = tf.one_hot(y, depth=100)
                #计算损失
                loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
                loss = tf.reduce_mean(loss)
            #计算梯度
            grads = tape.gradient(loss, resNet.trainable_variables)
            #更新参数
            optimizer.apply_gradients(zip(grads, resNet.trainable_variables))

            if (step % 100 == 0):
                print("Epoch[", epoch + 1, "/", num_epoches, "]: step - ", step, " loss:", float(loss))
        #进行验证
        total_samples = 0
        total_correct = 0
        for x,y in test_db:
            logits = resNet(x)
            prob = tf.nn.softmax(logits, axis=1)
            pred = tf.argmax(prob, axis=1)
            pred = tf.cast(pred, dtype=tf.int32)
            correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)
            correct = tf.reduce_sum(correct)

            total_samples += x.shape[0]
            total_correct += int(correct)

        #统计准确率
        acc = total_correct / total_samples
        print("Epoch[", epoch + 1, "/", num_epoches, "]: accuracy:", acc)

if __name__ == '__main__':
    main()

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1638681.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

自动安装环境shell脚本使用和运维基础使用讲解

title: 自动安装环境shell脚本使用和运维基础使用讲解 tags: [shell,linux,运维] categories: [开发记录,系统运维] date: 2024-3-27 14:10:15 description: 准备和说明 确认有网。 依赖程序集,官网只提供32位压缩包,手动编译安装后,在64位机…

动态数据结构中的表扩张性:摊还分析、伪代码与C语言实现

动态数据结构中的表扩张性:摊还分析、伪代码与C语言实现 引言表扩张性的概念摊还分析在表扩张性中的应用伪代码示例:TABLE-INSERT操作C语言实现结论 引言 在处理数据结构时,尤其是表(或数组),我们经常面临…

STM32标准库编译前置条件配置

本文基于stm32f104系列芯片,记录编程代码前需要的操作: 添加库文件 在ST官网下载标准库STM32F10x_StdPeriph_Lib_V3.5.0,解压后,得到以下界面 启动文件 进入Libraries,然后进入CMSIS,再进入CM3&#xff…

CUDA内存模型

核函数性能并不只与线程束的执行有关。 CUDA内存模型概述 GPU和CPU内存模型的主要区别是,CUDA编程模型能将内存层次结构更好地呈现给用户,能让我们显示的控制它的行为。 对程序员来说,一般有两种类型的存储器: 可编程的&#x…

【Qt QML】用CMake管理Qt工程

CMake是一个开源、跨平台的工具系列,用于构建、测试和打包软件。CMake使用简单的独立配置文件来控制软件编译过程。与许多跨平台系统不同,CMake被设计为与本地构建环境结合使用。 下面我们在CMake项目中使用Qt的最基本方法。首先,创建一个基本…

向量体系结构:向量执行时间

看《计算机体系结构 量化研究方法》做的笔记,接着上一篇写 计算机体系结构:向量体系结构介绍-CSDN博客 向量处理器工作的示例 SAXPY或DAXPY循环。 aXY SAXPY代表“单精度aX加Y”,进行单精度浮点数的运算,其中a是一个标量&#x…

测试开发工具开发 -JMeter 函数二次开发

在JMeter中开发自定义函数是一个常见的需求,允许我们扩展JMeter的功能以适应特定的测试需求。自定义函数可以用来处理数据,生成输出,或者执行特定的运算。通过JMeter函数二次开发可以帮我们解决实际测试过程中造数难的问题 用过JMeter的同学…

搭建vue3组件库(三): CSS架构之BEM

文章目录 1. 通过 JS 生成 BEM 规范名称1.1 初始化 hooks 目录1.2 创建 BEM 命名空间函数1.3 通过 SCSS 生成 BEM 规范样式 2. 测试 BEM 规范 BEM 是由 Yandex 团队提出的一种 CSS 命名方法论,即 Block(块)、Element(元素&#xf…

WORD排版常见问题与解决方案

前言 近期使用word软件进行论文排版工作,遇到了一些常见的问题,记录一下,避免遗忘。 基本配置 系统环境:win10/win11 word版本:Microsoft Office LTSC 专业增强版 2021 问题与解决方案 问题1:页眉显示内…

C——双向链表

一.链表的概念及结构 链表是一种物理存储单元上非连续、非顺序的存储结构,数据元素的逻辑顺序是通过链表中的指针链接次序实现的。什么意思呢?意思就是链表在物理结构上不一定是连续的,但在逻辑结构上一定是连续的。链表是由一个一个的节点连…

使用递归函数,将一串数字每位数相加求和

代码结果&#xff1a; #include<stdio.h> int DigitSum(unsigned int n) {if (n > 9)return DigitSum(n / 10) (n % 10);elsereturn n; } int main() {unsigned int n;scanf("%u", &n);int sum DigitSum(n);printf("%d\n", sum);return 0; …

持续更新|UNIAPP适配APP遇到的问题以及解决方案

在使用UNIAPP开发APP的时候遇到的一些奇奇怪怪问题记录 组件样式丢失 问题&#xff1a;组件引入界面中&#xff0c;在小程序和H5环境下样式正常&#xff0c;而在APP中却出现高度异常问题 解决&#xff1a;增加view标签将组件包裹起来即可正常显示 解决前&#xff1a; 解决后…

SCI一区 | MFO-CNN-LSTM-Mutilhead-Attention多变量时间序列预测(Matlab)

SCI一区 | MFO-CNN-LSTM-Mutilhead-Attention多变量时间序列预测&#xff08;Matlab&#xff09; 目录 SCI一区 | MFO-CNN-LSTM-Mutilhead-Attention多变量时间序列预测&#xff08;Matlab&#xff09;预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.Matlab实现MFO-CNN…

JAVA第二周学习笔记

文章目录 JAVA第二周学习笔记IDEA方法格式带参数及返回值的方法方法的重载方法的内存 二维数组静态初始化动态初始化 面向对象类和对象如何定义类如何得到对象注意 封装封装的优点private关键字成员变量和局部变量 this关键字构造方法作用类型特点执行时机定义重载 标准javabea…

Linux进程——进程的创建(fork的原理)

前言&#xff1a;在上一篇文章中&#xff0c;我们已经会使用getpid/getppid函数来查看pid和ppid,本篇文章会介绍第二种查看进程的方法&#xff0c;以及如何创建子进程&#xff01; 本篇主要内容&#xff1a; 查看进程的第二种方法创建子进程系统调用函数fork 在开始前&#xff…

什么是哈希表(HashTable)?

目录 一、概念 二、哈希冲突 减少哈希冲突的办法&#xff1a; 1、设计合理的哈希函数 哈希函数设计原则&#xff1a; 常用的哈希函数&#xff1a; 2、降低负载因子&#xff08;必须重点掌握&#xff09; 哈希冲突的解决 第一类&#xff1a;闭散列 第二类&…

实时监控RTSP视频流并通过YOLOv5-seg进行智能分析处理

在完成RTSP推流之后&#xff0c;尝试通过开发板接收的视频流数据进行目标检测&#xff0c;编写了一个shell脚本实现该功能&#xff0c;关于视频推流和rknn模型的部署请看之前的内容或者参考官方的文档。 #!/bin/bash # 设置脚本使用的shell解释器为bashSEGMENT_DIR"./seg…

【模板】前缀和

原题链接&#xff1a;登录—专业IT笔试面试备考平台_牛客网 目录 1. 题目描述 2. 思路分析 3. 代码实现 1. 题目描述 2. 思路分析 前缀和模板题。 前缀和中数组下标为1~n。 前缀和&#xff1a;pre[i]pre[i-1]a[i]; 某段区间 [l,r]的和&#xff1a;pre[r]-pre[l-1] 3.…

【数学建模】2024五一数学建模C题完整论文代码更新

最新更新&#xff1a;2024五一数学建模C题 煤矿深部开采冲击地压危险预测&#xff1a;建立基于多域特征融合与时间序列分解的信号检测与区间识别模型完整论文已更新 2024五一数学建模题完整代码和成品论文获取↓↓↓↓↓ https://www.yuque.com/u42168770/qv6z0d/gyoz9ou5upv…

unity制作app(2)--主界面

1.先跳转过来&#xff0c;做一个空壳&#xff01;新增场景main为4号场景&#xff01; 2.登录成功跳转到四号场景&#xff01; 2.在main场景中新建canvas&#xff0c;不同的状态计划用不同的panel来设计&#xff01; 增加canvas和底图image 3.突然输不出来中文了&#xff0c;浪…