卷积神经网络<二>keras实现多分支输入VGG

news2024/11/28 12:54:42

VGG的模型图

在这里插入图片描述

在这里插入图片描述

VGG使用Keras实现

这里的代码借鉴了VGG实现Keras,但是这段代码不支持多通道,并且vgg函数的扩展性不好。下面修改一下,方便进行多分支图片输入的建立,以及更见方便的调参。

# from keras.models import
from keras.layers import *
from keras.models import Input, load_model, Sequential
from keras import Model
from keras.datasets import mnist
from keras.utils import to_categorical
from keras.losses import categorical_crossentropy
import keras.optimizers
import numpy as np
 
 
def vgg(input_shape, num_cls, filters_num, conv_nums):
    # print(input_shape)
    inputs = Input(shape=input_shape)
    x = inputs
    for i in range(len(conv_nums)):
        for j in range(conv_nums[i]):
            x = Conv2D(filters=filters_num[i], kernel_size=3, padding='same',
                       name='stage{0}_conv{1}'.format(i+1, j+1))(x)
        x = MaxPool2D((2, 2), strides=2, name='maxpool_'+str(i+1))(x)
        x = ZeroPadding2D((1, 1))(x)
    x = Flatten(name='flatten')(x)
    x = Dense(units=4096, name='dense4096_1')(x)
    x = Dense(units=4096, name='dense4096_2')(x)
    x = Dense(units=num_cls, name='dense1000', activation='softmax')(x)
    model = Model(inputs=inputs, outputs=x, name='vgg')
    model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['acc'])
    return model
 
 
def train(net_name):
    path = r'C:\Users\.keras\datasets\mnist.npz'
    with np.load(path, allow_pickle=True) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
 
    x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32')
    x_test = x_test.reshape(x_test.shape[0], 28, 28, 1).astype('float32')
    num_classes = 10
    x_train = x_train / 255.
    x_test = x_test / 255.
    y_train = to_categorical(y_train, num_classes)
    y_test = to_categorical(y_test, num_classes)
 
    batch_size = 16
    epochs = 1
 
    if net_name == 'vgg-19':
        filters_num = [64, 128, 256, 512, 512]
        conv_nums = [2, 2, 4, 4, 4]
    else:
        filters_num = [32, 64, 128, 256, 512]
        conv_nums = [2, 2, 3, 3, 3]
    vgg_model = vgg(input_shape=(28, 28, 1), num_cls=num_classes, filters_num=filters_num,
                    conv_nums=conv_nums)
    vgg_model.summary()
    vgg_model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, validation_split=0.1)
    vgg_model.save('{0}-mnist.h5'.format(net_name))
    eval_res = vgg_model.evaluate(x_test, y_test)
    print(eval_res)
 
 
if __name__ == '__main__':
    train('vgg-16')

方便调参版本

  1. 把optimizer和input等重要参数进行了封装,方便调参和调用。
  2. 建立了多分支输入函数build_multy_vgg(),方便调用。
"""
@author:fuzekun
@file:VGG_Model.py
@time:2022/11/22
@description: 定义VGG的模型进行图片的训练,首先只使用rri进行训练
"""

# from keras.models import
from keras.layers import *
from keras.models import Input, load_model, Sequential
from keras import Model
from keras.datasets import mnist
from keras.utils.all_utils import to_categorical
from keras.losses import categorical_crossentropy
import keras.optimizers
import numpy as np
import tensorflow as tf
"""
这里建立模型的时候,压缩到最后就没有了,个人以为是图片太小导致的,所以240的时候可以去掉zero那一层
"""

def vgg(input_shape, num_cls, filters_num, conv_nums, multy):
    # print(input_shape)
    inputs = Input(shape=input_shape)
    x = inputs
    for i in range(len(conv_nums)):
        for j in range(conv_nums[i]):
            x = Conv2D(filters=filters_num[i], kernel_size=3, padding='same')(x)
        x = MaxPool2D((2, 2), strides=2)(x)
        if input_shape[0] < 224:
            x = ZeroPadding2D((1, 1))(x)
    x = Flatten()(x)
    x = Dense(units=4096)(x)
    x = Dense(units=4096)(x)
    if not multy:      # 单模型直接输出到类别
        x = Dense(units=num_cls, activation='softmax')(x)
    model = Model(inputs=inputs, outputs=x, name='vgg')
    return model


def build_vgg(net_name, input_shape, num_classes, optimizer, filter_num=[], conv_nums=[], multy = False):

    if net_name == 'vgg-19':
        filters_num = [64, 128, 256, 512, 512]
        conv_nums = [2, 2, 4, 4, 4]
    else:
        filters_num = [32, 64, 128, 256, 512]
        conv_nums = [2, 2, 3, 3, 3]

    vgg_model = vgg(input_shape=input_shape, num_cls=num_classes, filters_num=filters_num,
                    conv_nums=conv_nums, multy=multy)
    if not multy:   # 多输入的不进行编译
        vgg_model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['acc'])
    # vgg_model.summary()
    return vgg_model

# 创建多输入的VGG模型
def build_multy_VGG(net_name, input_shape, num_classes, optimizer, n_hiddens,
                    filter_num=[], conv_nums=[]):
    out_rri = build_vgg(net_name, input_shape, num_classes, optimizer, filter_num, conv_nums, True)
    out_edr = build_vgg(net_name, input_shape, num_classes, optimizer, filter_num, conv_nums, True)
    out_amp = build_vgg(net_name, input_shape, num_classes, optimizer, filter_num, conv_nums, True)
    # 2. 进行模型融合
    # print(out_rri.output)
    combined = concatenate([out_rri.output, out_edr.output])  # (None, 7, 7, 768)
    # print(combined)
    # 2.1融合输入
    x = Dense(n_hiddens, activation='relu')(combined)
    x = Flatten()(x)
    # 2.2最后输出
    x = Dense(num_classes, activation='softmax')(x)
    # 2.3模型定义完成
    model = Model(inputs=[out_rri.input, out_edr.input], outputs=x)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['acc'])
    return model

def train(net_name):
    path = r'C:\Users\.keras\datasets\mnist.npz'
    with np.load(path, allow_pickle=True) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']

    x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32')
    x_test = x_test.reshape(x_test.shape[0], 28, 28, 1).astype('float32')
    num_classes = 10
    x_train = x_train / 255.
    x_test = x_test / 255.
    y_train = to_categorical(y_train, num_classes)
    y_test = to_categorical(y_test, num_classes)

    batch_size = 16
    epochs = 1

    lr = 0.001
    opt = tf.keras.optimizers.Adam(learning_rate=lr)
    model = build_vgg("vgg-16", input_shape=(28,28,1), num_classes=2, optimizer=opt)
    model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, validation_split=0.1)
    model.save('{0}-mnist.h5'.format(net_name))
    eval_res = model.evaluate(x_test, y_test)
    print(eval_res)


if __name__ == '__main__':
    train('vgg-16')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/26804.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MyBatis介绍

MyBatis介绍 MyBatis 是一款优秀的持久层框架&#xff0c;它支持定制化 SQL、存储过程以及高级映射。MyBatis 避免了几乎所有的 JDBC 代码和手动设置参数以及获取结果集。MyBatis 可以使用简单的 XML 或注解来配置和映射原生信息&#xff0c;将接口和 Java 的 POJOs(Plain Ordi…

PMP考试自学可以吗?(含PMP备考资料)

当然是可以的&#xff0c;只要解决了“报考的35学时”这个问题&#xff0c;就只剩怎么备考的问题了。 在一般情况下&#xff0c;建议备考一到三个月&#xff0c;别给自己太长或太短的备考时间&#xff0c;前者坚持不下来&#xff0c;后者备考时间太少&#xff0c;来不及备考充…

戴尔大步进军经典量子计算混合模型

​ &#xff08;图片来源&#xff1a;网络&#xff09; 戴尔正将量子计算机融入传统IT的基础架构中&#xff0c;并向新型加速计算机开放了数据中心。这家服务器制造商为传统服务器基础设施创建了一个蓝图&#xff0c;以满足量子系统的独特需求&#xff0c;量子系统速度要比经典…

基于物联网的汽车爆胎预警系统

本设计是基于物联网的汽车爆胎预警系统的设计与实现设计&#xff0c;主要实现以下功能&#xff1a; 1&#xff0c;主机用LCD1602显示温度、气压和距离&#xff1b; 2&#xff0c;主从机间通过ZigBee进行数据的传输&#xff1b; 3&#xff0c;从机检测轮胎气压&#xff0c;温度…

zbxtable

ubuntu install zbxtable 1.新建zbxtable文件夹&#xff0c;将三件套下载到本地存放 mkdir ~/zbxtable ZbxTable: https://dl.cactifans.com/zbxtable/zbxtable-2.1.0.tar.gz ZbxTable-Web: https://dl.cactifans.com/zbxtable/web.tar.gz MS-Agent: https://dl.cactifans.co…

MybatisMybatisPlusSpringboot之最全入门学习教程笔记

1 Mybatis概述 1.1 Mybatis概念 MyBatis 是一款优秀的持久层框架&#xff0c;用于简化 JDBC 开发, &#xff08;1&#xff09;持久层&#xff1a;负责将数据到保存到数据库的那一层代码。Mybatis就是对jdbc代码进行了封装。 JavaEE三层架构&#xff1a;表现层、业务层、持久层…

ftp协议主动模式与被动模式

FTP主动模式与被动模式 主动模式&#xff1a;客户端给服务端的21控制端口发命令说&#xff0c;我要下载什么什么&#xff0c;并且还会说我已经打开了自己的某个端口&#xff0c;你就从这里把东西给我吧&#xff0c;服务器知道后就会连接客户端已打开的那个数据端口把东西传给客…

SpringBoot配置https

1.Https配置 由于HTTPS具有良好的安全性&#xff0c;在开发中得到了越来越广泛的应用&#xff0c;像微信公众号、小程序等的开发都要使用HTTPS来完成。对于个人开发者而言&#xff0c;一个HTTPS 证书的价格还是有点贵&#xff0c;但是呢&#xff0c;国内的一些云服务器厂商提供…

android studio 项目生成apk的几个问题(备忘)

终于自己做的一个小东西要做完了&#xff0c;最后一步是生成apk&#xff0c;这之前遇到几个问题备忘一下。 1、安装完成后有两个图标&#xff0c;分别是两个activity&#xff0c;卸载一个&#xff0c;另一个也没了。 原因&#xff1a;我原来做一时候没有欢迎界面&#xff0c;…

总结数据结构常用树

树&#xff1a;只有一个根节点&#xff0c;有孩子结点&#xff0c;父节点 二叉树&#xff1a;每个节点最多有两个孩子结点。 二分搜索树&#xff1a;也叫二叉排序树&#xff0c;首先它是一颗二叉树&#xff0c;且左右孩子都存在时&#xff0c;左孩子都小于当前节点值&#xf…

计算机数制(进制转换,原码,反码,补码,真值)

目录 区分进制 带小数点的进制转化 进制转换练习 符号数的表示方法 区分&#xff1a; 考点&#xff1a;给你原码转换补码&#xff0c;补码最负的数的表示&#xff0c;0的表示 原码&#xff0c;反码&#xff0c;补码练习 区分进制 注意微机原理这门课用的是后缀的方式&#xff0…

小学生python游戏编程arcade----碰撞精灵消失问题

小学生python游戏编程arcade----碰撞精灵消失问题前言碰撞精灵消失问题1、多余的精灵不能及时消失1.1 问题1.2 失败代码1.3 记录备忘1.4 代码实现2、放置位置2.1 代码放在ondraw中可以2.2 在update中也可以2.3 碰撞中3、玩家子弹击中敌坦克后的爆炸效果3.1 爆炸类3.2 爆炸列表准…

2022年铁路行业研究报告

第一章 行业概况 铁路运输是主要的陆上交通运输方式之一&#xff0c;铁路是指在综合交通运输体系中&#xff0c;用于运行火车等交通工具行驶的轨道线路。铁路运输是主要的陆上交通运输方式之一&#xff0c;是通过机车牵引车辆组成列车在铁轨上运送客或货的一种运输方式。相比其…

AVL双旋转思路分析与图解

AVL树双旋转思路分析与图解 首先我们要知道什么情况之下我们是要进行双旋转? 当最小不平衡子树为LR型或者RL型的时候, 那么什么时候最小不平衡子树是RL型或者什么时候又是LR型的? 下面我们就先给出LR型, RL型, LL型, RR型最小不平衡子树的概念: LR型最小不平衡子树: 首先拿…

Linux 动静态库

目录 静态库和动态库 gcc规则使用动静态库的规则&#xff1a; 制作静态库 使用静态库 方法1. 方法2. 制作动态库 使用动态库 方法1&#xff1a; 方法2&#xff1a; 方法3&#xff1a; 方法4&#xff1a; 进程&#xff0c;静态库&#xff0c;动态库 静态库和动态库 …

传统瀑布模型和实际瀑布模型

传统瀑布模型&#xff1a; 瀑布模型是所有模型的基础框架 特点&#xff1a; 线性的开发流程&#xff0c;不能够应对需求的变化。 必须等前一阶段的工作完成后&#xff0c;才能开始后一阶段的工作 前一阶段的输出文档就是后一阶段的输入文档&#xff0c;因此只有前一阶段的输…

Map及其实现类、锁

HashMap、HashTable、ConcurrentHashMap 区别 一.HashMap和HashTable的区别 1、两者父类不同 HashMap是继承自AbstractMap类&#xff0c;而Hashtable是继承自Dictionary类。不过它们都实现了同时实现了map、Cloneable&#xff08;可复制&#xff09;、Serializable&#xff0…

朱松纯教授场景理解相关文章简介

朱松纯教授场景理解相关文章简介 Holistic 3D Scene Parsing and Reconstruction from a Single RGB Image 基于单张图像的整体场景解译与重建 我们提出了一个计算框架来联合解译单帧RGB图像&#xff0c;通过使用一系列的随机语法模型生成的CAD模型构成整体的3D结构。具体地说…

智慧农业SaaS系统

真正的大师,永远都怀着一颗学徒的心&#xff01; 一、项目简介 智慧农业SaaS系统 二、实现功能 监控管理&#xff1a;支持海康摄像头监控。 用户管理&#xff1a;支持用户是系统操作者&#xff0c;该功能主要完成系统用户配置。 岗位管理&#xff1a;支持配置系统用户所属担…

bugku渗透测试 1 writeup(无需VPS)

靶场地址&#xff1a;BugKu渗透测试1 该靶场只需要20金币就可以开启两小时&#xff0c;算的上非常良心实惠了&#xff0c;趁着有空赶紧刷一刷题目 目录 第一场景&#xff1a; 第二场景&#xff1a; 第三场景&#xff1a; 第四场景&#xff1a; 第五场景&#xff1a; 第六…