利用VGG16网络模块进行迁移学习实现图像识别

news2024/11/19 23:13:06

​ ImageNet虽然带有”Net“,但他不是一种深度神经网络模型,它是个数据集,斯坦福大学教授李飞飞带头建立,是目前图像分类、检测、定位的最常用数据集之一。该数据集含大量数据1500万图片,2.2万类别,真彩图(RGB三通道)。

​ mnist是手写数字识别数据集,训练集包含60000 张图像和标签,而测试集包含了10000 张图像和标签,每张图片是一个28*28像素点的0 ~ 9的灰质手写数字图片,黑底白字,图像像素值为0 ~ 255,越大该点越白。

​ 本例将在ImageNet训练的VGG16模块迁移至mnist模块。

导入keras中有关神经网络的功能模块

import tensorflow
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.layers import Input, Flatten, Dense, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD

导入VGG16模块

from tensorflow.keras.applications.vgg16 import VGG16

导入keras自带的mnist数据集

from tensorflow.keras.datasets import mnist

导入图像处理和计算模块

import cv2
import numpy as np

只迁移网络结构,不迁移网络权重

model_vgg = VGG16(include_top=False, weights='imagenet',
                  input_shape=(224, 224, 3))  
model = Flatten(name='flatten')(model_vgg.output)
model = Dense(10, activation='softmax')(model)
model_vgg_mnist=Model(model_vgg.input,model,name='vgg16')
print(model_vgg_mnist.summary())

其中,include_top=False 表示迁移除顶层之外的其余网络结构到自己模型中。

Model: "vgg16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0         
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0         
_________________________________________________________________
dense (Dense)                (None, 10)                250890    
=================================================================
Total params: 14,965,578
Trainable params: 14,965,578
Non-trainable params: 0

可以看到,需要训练的参数个数为14,965,578,达到千万级别。

迁移权重的优点:网络权值不需要重新训练,只需要训练输入层的网络权值即可;
迁移权重的缺点:已有权重是基于imagenet数据集进行训练的,样本种类、数据分布等与本次训练的数据集不一定相似。

同时迁移网络结构和权重

ishape = 224
model_vgg = VGG16(include_top=False, weights='imagenet', input_shape=(ishape, ishape, 3))
for layer in model_vgg.layers:
    layer.trainable = False
model = Flatten(name='flatten')(model_vgg.output)
model = Dense(10, activation='softmax')(model)
model_vgg_mnist = Model(model_vgg.input, model, name='vgg16_pretrain')
print(model_vgg_mnist.summary())
Model: "vgg16_pretrain"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0         
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0         
_________________________________________________________________
dense (Dense)                (None, 10)                250890    
=================================================================
Total params: 14,965,578
Trainable params: 250,890
Non-trainable params: 14,714,688
_________________________________________________________________
None

只需要训练250890(十万级)个参数

实践

由于训练时间较长,本例中设置输入图像尺寸为(56,56)

Model: "vgg16_pretrain"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 56, 56, 3)]       0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 56, 56, 64)        1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 56, 56, 64)        36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 28, 28, 64)        0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 28, 28, 128)       73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 28, 28, 128)       147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 14, 14, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 14, 14, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 14, 14, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 14, 14, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 7, 7, 256)         0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 7, 7, 512)         1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 7, 7, 512)         2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 7, 7, 512)         2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 3, 3, 512)         0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 3, 3, 512)         2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 3, 3, 512)         2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 3, 3, 512)         2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 1, 1, 512)         0         
_________________________________________________________________
flatten (Flatten)            (None, 512)               0         
_________________________________________________________________
dense (Dense)                (None, 10)                5130      
=================================================================
Total params: 14,719,818
Trainable params: 5,130
Non-trainable params: 14,714,688
_________________________________________________________________
None

对mnist数据集进行处理

(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(x_train[0])
print(x_train[0].shape)  # (28, 28)

将 (28, 28)变成(56,56),mnist图像是黑白的,需要转成三维数据。

ishape = 56
x_train = [cv2.cvtColor(cv2.resize (i, (ishape, ishape)),cv2.COLOR_GRAY2BGR) for i in x_train]
x_train = np.concatenate ([arr[ np.newaxis] for arr in x_train] ).astype ("float32")
x_test = [cv2.cvtColor(cv2.resize (i, (ishape, ishape)),cv2.COLOR_GRAY2BGR) for i in x_test ]
x_test = np.concatenate ([arr[ np.newaxis] for arr in x_test ] ).astype ("float32")
print(x_train.shape)#(60000, 56, 56, 3)
print(x_test.shape)#(10000, 56, 56, 3)

归一化

x_train/=255
x_test/=255

将标签y进行one-hot编码

def tran_y(y):
    y_ohe=np.zeros(10)
    y_ohe[y]=1
    return y_ohe

y_train_ohe = np.array([tran_y(y_train[i]) for i in range (len (y_train) )])
y_test_ohe = np. array([ tran_y(y_test [i]) for i in range (len (y_test) )])

引入tensorboard

tensorboard =tensorflow.keras.callbacks.TensorBoard(histogram_freq=1)

模型拟合

model_vgg_mnist.fit(x_train, y_train_ohe,validation_split=0.2,epochs=200,batch_size=128,shuffle=True,callbacks=[tensorboard])

模型保存

tensorflow.saved_model.save(model_vgg_mnist,'tflearn-vgg-mnist')

获得损失值和准确率

loss1, accuracy1 = model.evaluate(x_train, y_train_ohe)
loss2, accuracy2 = model.evaluate(x_test, y_test_ohe)

作者为节省时间,只设置了epochs=2来进行实践,结果如下:
训练集loss=0.5910102128982544,准确率=0.8633999824523926
测试集loss=0.5841793417930603,准确率=0.8659999966621399

保存的模型文件如下
在这里插入图片描述
​ 路径下生成了logs文件,打开cmd
输入

tensorboard --logdir=C:\Users\ThinkStation\Desktop\logs\train

​ 可以查看训练时的一些损失曲线等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/72112.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PPT免费放送|Zabbix峰会结束了?还有件儿事!

精彩的Zabbix峰会成功举办,这并不意味着学习交流结束,还有件儿事——17份PPT免费获取,网盘见文末。干货满满细细品味。也欢迎你留言评价! 值得一提的是:峰会中有理有据说明:Zabbix支持信创。开源免费的Zab…

IBDP学生如何申请中国香港的大学?

作为世界上最具竞争力的城市之一,香港拥有一些亚洲乃至世界上最好的大学。当然,这也使得香港成为内地学生以及国际留学生最喜爱的留学目的地之一。中国香港的教育在很大程度上是模仿英国的教育体系,但本科课程通常是英国和美国体系的混合体。…

Android 中的广播机制

一、Android广播概念: 在Android中,有一些操作完成以后,会发送广播,Android系统内部产生这些事件后广播这些事件,至于广播接收对象是否关心这些事件,以及它们如何处理这些事件,都由广播接收对象…

(附源码)ssm日语学习系统 毕业设计 271621

基于ssm日语学习系统 摘 要 信息化社会内需要与之针对性的信息获取途径,但是途径的扩展基本上为人们所努力的方向,由于站在的角度存在偏差,人们经常能够获得不同类型信息,这也是技术最为难以攻克的课题。针对日语学习等问题&#…

MySQL——表的内容增删查改

文章目录表的增删改查一、Create1、单行全列插入2、多行数据指定列插入3、插入否则更新4、替换二、Retrieve😊(重点)2.1 select 列2.1.1 全列查询2.1.2 指定列查询2.1.3 查询字段为表达式2.1.4 为查询结果指定别名2.1.5 结果去重2.2 where查询2.3 结果排序2.4 筛选分…

【VC7升级VC8】将vCenter Server 7.X 升级为 vCenter Server 8 (下)—— 升级步骤说明

目录前文说明3. 第一阶段升级(1)点击【升级】(2)升级介绍(3)最终中用户许可协议(4)连接到源设备(5)VC7与ESXi 证书警告(6)vCenter Ser…

【Vue】各种loader的基本配置与使用

✍️ 作者简介: 前端新手学习中。 💂 作者主页: 作者主页查看更多前端教学 🎓 专栏分享:css重难点教学 Node.js教学 从头开始学习 ajax学习 目录webpack中的loader  loader概述  打包处理css文件  打包处理less文件  打包处理图片  …

LWIP框架

目录 协议栈分层思想 1. 网络接口层 2. 网络层 3. 传输层 4. 应用层 进程模型 单进程模型 协议栈编程接口 1、Raw/Callback API 2、Netconn API 3、Socket API 协议栈分层思想 TCP/IP协议完整的包含了一系列构成互联网基础的网络协议,TCP/IP协议的开发出…

HTTP Digest Authentication 使用心得

简介 浏览器弹出这个原生的对话框,想必大家都不陌生,就是 HTTP Baisc 认证的机制。 这是浏览器自带的,遵循 RFC2617/7617 协议。但必须指出的是,遇到这界面,不一定是 Basic Authentication,也可能是 Dige…

墨门云终端行为趋势报表,泄密风险提前预警

事件响应滞后,事后再补救,为时晚矣,据IBM的数据泄露成本报告显示,加强风险监测可更快发现数据泄露行为,有效降低企业的数据泄露成本,可见建立完善的风险预警响应机制,可以避免更大的损失&#x…

5G无线技术基础自学系列 | NSA组网场景下移动性管理

素材来源:《5G无线网络规划与优化》 一边学习一边整理内容,并与大家分享,侵权即删,谢谢支持! 附上汇总贴:5G无线技术基础自学系列 | 汇总_COCOgsta的博客-CSDN博客 NSA组网场景下移动性管理涉及的相关概念…

js操作二进制数据

使用ArrayBuffer对象保存二进制数据,使用TypedArray和DataView 视图来读写数据。 ArrayBuffer代码内存中的一段数据 const buff new ArrayBuffer(4)这样就创建了一个4(byte)字节的长度的内存判断,初始值都为0 注:一般中文占2个字节&#xff…

葡聚糖修饰Hrps共价三聚肽|葡聚糖修饰CdSe量子点

葡聚糖修饰Hrps共价三聚肽|葡聚糖修饰CdSe量子点 葡聚糖修饰Hrps共价三聚肽 中文名称:葡聚糖修饰Hrps共价三聚肽 纯度:95% 存储条件:-20C,避光,避湿 外观:固体或粘性液体 包装:瓶装/袋装 溶解性&am…

爆火的 ChatGPT 会让客服岗位消失吗?

近日,由 OpenAI 推出的 ChatGPT 在全球互联网爆火。具体有多火呢?根据 OpenAI 的 CEO Sam Altman 的说法:上周三才上线的 ChatGPT,短短几天,用户数已突破 100 万大关。 那么,ChatGPT 是什么呢?…

无线充电智能车的制作

本文素材来源于宁夏大学 作者:白二曹、王瑞、穆琴、王童兵 指导老师:康彩 一、项目简介 1.功能介绍 无线充电智能车由无线充电、自动控制、红外遥控、网页显示四部分组成。 (1)流程描述 用户端浏览器访问http://127.0.0.1页面…

Cy5.5 Tyramide,Cyanine5.5 Tyramide,花青素Cy5.5 酪酰胺化学试剂供应

一:产品描述 1、名称 英文:Cy5.5 Tyramide,Cyanine5.5 Tyramide 中文:花青素Cy5.5 酪酰胺 2、CAS编号:N/A 3、所属分类:Cyanine 4、分子量:738.4 5、分子式:C48H52CIN3O2 6、…

如何用DOS命令设置ip地址及DNS

用DOS命令设置ip地址及DNS 设置/修改IP地址,子网掩码,网关的格式: netsh interface ip set address "本地连接" static 10.25.35.35 255.255.255.0 10.25.35.7 auto[more] 命令的意思是将“本地连接” ip地址设置成 10.25.35.3…

SD NAND 的 SDIO在STM32上的应用详解(下篇)

七.SDIO外设结构体 其实前面关于SDIO寄存器的讲解已经比较详细了,这里再借助于关于SDIO结构体再进行总结一遍。 标准库函数对 SDIO 外设建立了三个初始化结构体,分别为 SDIO 初始化结构体SDIO_InitTypeDef、SDIO 命令初始化结构体 SDIO_CmdInitTypeDef…

Lq93:复原 IP 地址

有效 IP 地址 正好由四个整数(每个整数位于 0 到 255 之间组成,且不能含有前导 0),整数之间用 . 分隔。 例如:"0.1.2.201" 和 "192.168.1.1" 是 有效 IP 地址,但是 "0.011.255.2…

Java基础:Map集合

1. Map集合概述 现实生活中,我们常会看到这样的一种集合:IP地址与主机名,身份证号与个人,系统用户名与系统用户对象等,这种一一对应的关系,就叫做映射。Java提供了专门的集合类用来存放这种对象关系的对象…