InceptionNet与ResNet

news2024/11/16 6:50:15

以下代码图片思路来源:
北京大学Tensor flow笔记
嗯,最近学了一下神经网络,并没有很难,主要是把代码背下来,然后掌握Tensorflow是怎么搭建网络的,Tensorflow是比pytorch好用的,我直接抄的代码里面,训练还要自己写循环,,而tensonflow直接调用fit函数即可
和老师做了一下InceptionNet还有ResNet,ResNet主要是有一条path,由于维度不同需要使用1*1卷积核结合步长还有卷积核的数量来个输入的图片降低维度
Inception结构:
在这里插入图片描述
对于一个Inception块:
在这里插入图片描述
是长这样的,其实还是像搭积木一样网上堆,如果你想改变输入input的行进方向的话,只需要在call函数里面改一下就可以:

 def call(self, x):
        print("input_shape:",x.shape)
        x1=self.c1(x)
        print("x1:",x1.shape)
        x2_1=self.c2_1(x)
        print("x2_1:",x2_1.shape)
        x2=self.c2_2(x2_1)
        print("x2:",x2.shape)
        x3_1=self.c3_1(x)
        print("x3_1:",x3_1.shape)
        x3=self.c3_2(x3_1)
        print("x3:",x3.shape)
        x4_1=self.p4_1(x)
        print("x4_1:",x4_1.shape)
        x4=self.c4_2(x4_1)
        print("x4:",x4.shape)
        y=tf.concat([x1,x2,x3,x4],axis=3)

然后在init里面把积木初始化就可以了,甚至你可以不按顺序的初始化
再来说ResNet,ResNet是怕前面的特征丢失,然后把前面的结果另外用一条线引到后面来:
在这里插入图片描述
其中虚线是需要一条下采样的path,实线是不需要下采样的
在call函数里面这样写就可以:

 def call(self, x):
        y=self.c1(x)
        y=self.c2(y)
        a=x
        if(self.path):
            a=self.c3(a)
        return y+a#残差网络输出的时候记得相加

嘶,嗯,,我这只有CPU,跑了半个小时之后训练完了一个epoch,真好,
在这里插入图片描述

整体代码:
ResNet

#本文件实现对MNIST数据集的卷积神经网络训练,实现断点续训等操作
import keras
import tensorflow as tf
from tensorflow.python.keras import Model
from tensorflow.python.keras.layers import Conv2D,Dense,Flatten,MaxPool2D,Dropout,Activation
from keras.layers import BatchNormalization


(x_train,y_train),(x_test,y_test)=tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train=x_train.reshape(x_train.shape[0],28,28,1)
x_test=x_test.reshape(x_test.shape[0],28,28,1)

class resnet18(Model):
    def __init__(self,path,filters,kernel_size,strides):
        super(resnet18,self).__init__()
        self.path=path
        self.c1=Conv2D(kernel_size=kernel_size,strides=strides,filters=filters,padding='same')
        self.c2=Conv2D(kernel_size=kernel_size,strides=1,filters=filters,padding='same')
        if(path):
            self.c3=Conv2D(kernel_size=1,strides=strides,filters=filters,use_bias=False,padding='same')
    def call(self, x):
        y=self.c1(x)
        y=self.c2(y)
        a=x
        if(self.path):
            a=self.c3(a)
        return y+a#残差网络输出的时候记得相加
class resnet(Model):
    def __init__(self):
        super(resnet,self).__init__()
        self.block=tf.keras.models.Sequential()
        begin=64
        for i in range(4):
            for j in range(2):
                if(i!=0 and j==0):
                    self.block.add(resnet18(1,begin,3,2))
                else:
                    self.block.add(resnet18(0,begin,3,1))
            begin *= 2
    def call(self, x):
        y=self.block(x)
        return y
class simple(Model):
    def __init__(self):
        super(simple,self).__init__()
        self.c1=Conv2D(kernel_size=3,filters=64,padding='same')
        self.b1=resnet()
        self.p1=keras.layers.GlobalAvgPool2D()
        self.d1=Dense(10,activation='softmax')
    def call(self,x):
        y=self.c1(x)
        y=self.b1(y)
        y=self.p1(y)
        y=self.d1(y)
        return y
model=simple()

model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model_save_path='./conv_weights/weight.ckpt'

bp=tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path,save_best_only=True,save_weights_only=True)

model.fit(x_train,y_train,batch_size=32,epochs=5,validation_data=(x_test,y_test),
          validation_freq=1,callbacks=bp)

InceptionNect代码:

#本文件实现对MNIST数据集的卷积神经网络训练,实现断点续训等操作
import keras
import tensorflow as tf
from tensorflow.python.keras import Model
from tensorflow.python.keras.layers import Conv2D,Dense,Flatten,MaxPool2D,Dropout,Activation
from keras.layers import BatchNormalization


(x_train,y_train),(x_test,y_test)=tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train=x_train.reshape(x_train.shape[0],28,28,1)
x_test=x_test.reshape(x_test.shape[0],28,28,1)
class conv(Model):
    def __init__(self,output_shape,stride,size):
        super(conv,self).__init__()
        self.c1=Conv2D(output_shape,strides=stride,kernel_size=size,padding='same')
        self.b1=BatchNormalization()
        self.a1=Activation('relu')
    def call(self,x):
        y=self.c1(x)
        y=self.b1(y)
        y=self.a1(y)
        return y
#下面造Inception块
class Inception(Model):
    def __init__(self,output_shape,strides):
        super(Inception,self).__init__()
        self.c1=conv(output_shape=output_shape,stride=strides,size=1)

        self.c2_1=conv(output_shape=output_shape,stride=strides,size=1)
        self.c2_2=conv(output_shape=output_shape,stride=1,size=3)

        self.c3_1=conv(output_shape=output_shape,stride=strides,size=1)
        self.c3_2=conv(output_shape=output_shape,stride=1,size=5)

        self.p4_1=MaxPool2D(pool_size=3,strides=1,padding='same')
        self.c4_2=conv(output_shape=output_shape,stride=strides,size=1)
    def call(self, x):
        print("input_shape:",x.shape)
        x1=self.c1(x)
        print("x1:",x1.shape)
        x2_1=self.c2_1(x)
        print("x2_1:",x2_1.shape)
        x2=self.c2_2(x2_1)
        print("x2:",x2.shape)
        x3_1=self.c3_1(x)
        print("x3_1:",x3_1.shape)
        x3=self.c3_2(x3_1)
        print("x3:",x3.shape)
        x4_1=self.p4_1(x)
        print("x4_1:",x4_1.shape)
        x4=self.c4_2(x4_1)
        print("x4:",x4.shape)
        y=tf.concat([x1,x2,x3,x4],axis=3)
        return y
class simple(Model):
    def __init__(self,num,classes,conv_channel):
        super(simple,self).__init__()
        self.c1=Conv2D(filters=16,kernel_size=3,padding='same')#先过一个3*3的卷积
        self.blocks=tf.keras.models.Sequential()
        for i in range(num):
            for j in range(2):
                if(j==0):
                    self.blocks.add(Inception(conv_channel,2))
                else:
                    self.blocks.add(Inception(conv_channel,1))
            conv_channel *= 2
        self.p1=keras.layers.GlobalAvgPool2D()
        self.d1=Dense(classes,activation='softmax')
    def call(self,x):
        y=self.c1(x)
        y=self.blocks(y)
        y=self.p1(y)
        y=self.d1(y)
        return y




model=simple(2,10,20)

model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model_save_path='./conv_weights/weight.ckpt'

bp=tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path,save_best_only=True,save_weights_only=True)

model.fit(x_train,y_train,batch_size=32,epochs=5,validation_data=(x_test,y_test),
          validation_freq=1,callbacks=bp)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/150894.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

UDS诊断系列介绍05-27服务

本文框架1. 系列介绍27服务概述2. 27服务请求与应答2.1 27服务请求2.2 27服务肯定应答2.3 27服务否定应答1. 系列介绍 UDS(Unified Diagnostic Services)协议,即统一的诊断服务,是面向整车所有ECU的一种诊断通信方式,…

java-操作excel

文章目录java操作Excel数据使用场景excel 03 和 07的区别POIeasyExcel解析excel表中的对象POI使用步骤POI 写数据POI 读数据计算公式easyExcel读写数据写数据读数据java操作Excel数据 在 平时 可以使用IO流对Excle进行操作 但是现在使用更加方便的第三方组件来实现 使用场景 1、…

在rhel7系统使用Mariadb

文章目录一 联系和区别二 需求三 部署安装3.1 环境准备3.2 安装软件包3.3 启动服务3.4 设置防火墙策略四 创建用户和库表4.1 登录数据库4.2 创建用户4.3 创建数据库和表五 备份和恢复5.1 备份 com 数据库5.2 模拟误删除操作5.3 恢复表一 联系和区别 Mariadb是由社区开发的一个…

4.4 集成运放的性能指标及低频等效电路

一、集成运放的性能指标 在考察集成运放的性能时,常用下列参数来描述: 1、开环差模增益 AodA_{od}Aod​ 在集成运放无外加反馈时的差模放大倍数称为差模开环增益,记作 AodA_{od}Aod​。AodΔuO/(uP−uN)A_{od}\Delta u_O/(u_P-u_N)Aod​Δ…

【Spring Cloud GateWay】ServerHttpResponseDecorator不生效

文章目录1. BUG描述2. BUG解决3. BUG分析1. BUG描述 在Spring Cloud Gateway使用编码的方式实现一个全局拦截器,在全局拦截器中想要打印响应日志。 于是自己装饰了一个具有打印日志功能的ServerHttpResponseDecorator,但是在转发后的服务返回响应的时候…

在线浏览PDF:Grapecity Documents for PDF Viewer 6.0.2

Grapecity Documents for PDF Viewer跨平台 JavaScript PDF 查看器---备注:必须配合.NET版本才能编辑PDF 使用我们的 JavaScript PDF 查看器在网络上阅读和编辑 PDF。跨浏览器和框架打开和打印。 Grapecity Documents for PDF Viewer全功能的 JavaScript PDF 查看器和 PDF 编辑…

move语言之基础学习(基本类型+表达式+变量)例子

一、基本类型 Move 的基本数据类型包括: 整型 (u8, u64, u128)、布尔型 boolean 和地址 address。 Move 不支持字符串和浮点数。 整型 整型包括 u8、u64 和 u128,我们通过下面的例子来理解整型: script { fun main() { // define empty variable, set…

python(0)计算机基础知识

文章目录计算机是什么计算机的组成计算机的使用方式windows的命令行文本文件和字符集乱码计算机是什么 在现实生活中,越来越无法离开计算机了 电脑、笔记本、手机、游戏机、汽车导航、智能电视。。。 计算机就是一个用来计算的机器 目前来讲,计算机只…

C++模板进阶+继承详解

耕耘和收获不是连贯的&#xff0c;中间还隔着很长一段时间&#xff0c;那就是坚持&#xff01;一&#xff1a;模板进阶1.1&#xff1a;非类型模板参数template<class T,size_t N> class arr { private:T _a[N]; };这里的N就跟define一样&#xff0c;属于非类型模板参数。…

MongoDB常用操作

官网地址&#xff1a;https://www.mongodb.com/docs/manual/reference/method/Date/ 实例&#xff1a;系统上运行的进程及节点集&#xff0c;一个实例可以有多个库&#xff0c;默认端口 27017。如果要在一台机器上启动多个实例&#xff0c;需要设置不同端口和不同的dbpath。库&…

第四章web服务器之httpd

文章目录第四章 web服务器1.1 www简介1.1.1 网址及HTTP简介1.1.2 HTTP协议请求的工作流程1.2 www服务器的类型1.2.1 仅提供用户浏览的单向静态网页1.2.2 提供用户互动接口的动态网站1.3 www服务器的基本配置1.4 实验1.4.1 搭建静态网站——基于http协议的静态网站1.4.2 搭建静态…

Acwing---1211.蚂蚁感冒

蚂蚁感冒1.题目2.基本思想3.代码实现1.题目 长 100 厘米的细长直杆子上有 nnn 只蚂蚁。 它们的头有的朝左&#xff0c;有的朝右。 每只蚂蚁都只能沿着杆子向前爬&#xff0c;速度是 1 厘米/秒。 当两只蚂蚁碰面时&#xff0c;它们会同时掉头往相反的方向爬行。 这些蚂蚁中…

C语言基本数据类型(一)

文章目录 前言 一、int类型 二、八进制和十六进制 三、其他整数类型 四、char 类型 五、_Bool 类型 六、 可移植类型&#xff1a;stdint.h和unttypes.h 前言 C语言基本数据类型包括声明变量、如何表示字面值常量&#xff0c;以及经典的用法。 一、int类型 C语言中包括许…

【openGauss】在openEuler(ARM架构)上安装openGauss(一主一备)

一、系统版本介绍 当前案例中的openGauss安装&#xff0c;底层操作系统为openEuler-20.03-LTS版本&#xff0c;当前openGauss对Python版本兼容性最好的是Python 3.6版本与Python 3.7版本&#xff0c;该实验使用的openEuler版本自带Python 3.7.4&#xff0c;不需要再自行安装 二…

零基础如何入门网络安全?2023年最新,建议收藏!

前言 最近收到不少关注朋友的私信和留言&#xff0c;大多数都是零基础小友入门网络安全&#xff0c;需要相关资源学习。其实看过的铁粉都知道&#xff0c;之前的文里是有过推荐过的。新来的小友可能不太清楚&#xff0c;这里就系统地叙述一遍。 01.简单了解一下网络安全 说白…

前端必会手写面试题合集

实现Event(event bus) event bus既是node中各个模块的基石&#xff0c;又是前端组件通信的依赖手段之一&#xff0c;同时涉及了订阅-发布设计模式&#xff0c;是非常重要的基础。 简单版&#xff1a; class EventEmeitter {constructor() {this._events this._events || ne…

电力系统IEEE33节点Simulink仿真研究(Matlab实现)

&#x1f4a5;&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️❤️&#x1f4a5;&#x1f4a5;&#x1f4a5; &#x1f389;作者研究&#xff1a;&#x1f3c5;&#x1f3c5;&#x1f3c5;主要研究方向是电力系统和智能算法、机器学…

arduino基本知识认识和学习资源

个人对ardunio的感觉 **像是一个模块化功能的单片机&#xff0c;编程时在单片机中就像python在计算机语言的感觉。**硬件方面的功能比较单一依赖于传感器和硬件电路&#xff1b;编程比较简单&#xff0c;所有执行的函数都已经被封装&#xff0c;所以想要成为第一个用这个库吃瓜…

【C语言刷题】猜名次、猜凶手、杨辉三角、杨氏矩阵、字符串左旋、判断是否为左旋子串

目录 一、猜名次 二、猜凶手 三、杨辉三角 解法一&#xff1a; 解法二 四、杨氏矩阵 解法一 解法二 五、字符串左旋 解法一 解法二 六、判断是否为字符串左旋字串 解法一 解法二 总结 一、猜名次 5位运动员参加了10米台跳水比赛&#xff0c;有人让他们预测比赛结果…

基于轻量级CNN开发构建学生课堂行为识别系统

其实早在之前&#xff0c;我的一些文章里面就有做过关于学生课堂行为检测识别的项目&#xff0c;感兴趣的话可以自行移步阅读&#xff1a;《yolov4-tiny目标检测模型实战——学生姿势行为检测》《基于yolov5轻量级的学生上课姿势检测识别分析系统》这些主要是偏目标检测类的项目…