TensorFlow2进行CIFAR-10数据集动物识别,保存模型并且进行外部下载图片测试

news2025/1/23 4:48:49

首先,你已经安装好anaconda3、创建好环境、下载好TensorFlow2模块并且下载好jupyter了,那么我们就直接打开jupyter开始进行CIFAR10数据集的训练。

第一步:下载CIFAR10数据集

下载网址:http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

将数据集下载到合适的路径,方便模型训练的时候调用

第二步:导入该导的库

# tensorflow1.x
import tensorflow as tf
import numpy as np
import os
from matplotlib import pyplot as plt

第三步:加载刚刚下载的数据集,如果你下载了 cifar-10-python.tar.gz那么就先解压这个压缩包,将里面的文件放入一个文件夹,我这里放在为cifar-10-batches-py目录下,所有文件如图

 然后加载该数据集

import pickle

def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

def load_data(path):
    # 读取训练数据
    train_images = []
    train_labels = []
    for i in range(1, 6):
        file = path + "/data_batch_{}".format(i)
        data = unpickle(file)
        train_images.append(data[b"data"])
        train_labels.append(data[b"labels"])
    train_images = np.concatenate(train_images)
    train_labels = np.concatenate(train_labels)
    # 读取测试数据
    file = path + "/test_batch"
    data = unpickle(file)
    test_images = data[b"data"]
    test_labels = np.array(data[b"labels"])
    # 转换数据类型
    train_images = train_images.astype(np.float32)
    test_images = test_images.astype(np.float32)
    y_train = np.array(train_labels)
    y_test = np.array(test_labels)
    # 将像素值缩放到[0, 1]范围内
    x_train = train_images/255.0
    x_test = test_images/255.0
    
    # 将标签数据转换为one-hot编码
#     train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=10)
#     test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=10)
    
    return (x_train, y_train), (x_test, y_test)


# 加载数据集
(train_images, train_labels), (test_images, test_labels) = load_data("../cifar_data/cifar-10-batches-py")
train_images = train_images.reshape(50000, 32, 32, 3)
test_images = test_images.reshape(10000, 32, 32, 3)

当然还有更简单的方法那就是使用TensorFlow内部模块下载数据集,如下

# 下载数据集
cifar10=tf.keras.datasets.cifar10

(x_train,y_train),(x_test,y_test)=cifar10.load_data()

x_train[0][0][0]
# 对图像images进行数字标准化
x_train=x_train.astype('float32')/255.0 
x_test = x_test.astype('float32')/ 255.0

第四步:数据集本来的标签是数字,我们可以将它转化成对应的类型名

label_dict={0:"airplane",1:"automobile",2:"bird",3:"cat",4:"deer",5:"dog", 6:"frog", 7:"horse", 8:"ship", 9:"truck"}

第五步:开始构建神经网络模型,这里我就简单构建一个类似AlexNet的卷积神网络模型

# 建立卷积神经网络CNN模型AlexNet
#建立Sequential线性堆叠模型
'''
Conv2D(filters=,kernel_size=,strides=,padding=,activation=,input_shape=,)
filters:卷积核数量,即输出的特征图数量。
kernel_size:卷积核大小,可以是一个整数或者一个元组,例如(3, 3)。
strides:卷积步长,可以是一个整数或者一个元组,例如(1, 1)。
padding:填充方式,可以是'same'或'valid'。'same'表示在输入图像四周填充0,保证输出特征图大小与输入图像大小相同;
        'valid'表示不填充,直接进行卷积运算。
activation:激活函数,可以是一个字符串、一个函数或者一个可调用对象。
input_shape:输入图像的形状
'''

'''
MaxPooling2D(pool_size=,strides=,padding=,)
pool_size:池化窗口大小,可以是一个整数或者一个元组,例如(2, 2)表示2x2的池化窗口。
'''


#this is a noe model,you just have to choose one or the other
def creatAlexNet():
    model = tf.keras.models.Sequential()#第1个卷积层
    model.add(tf.keras.layers.Conv2D(filters=32,
                                     kernel_size=(3,3), 
                                     input_shape=(32,32,3),
                                     activation='relu', padding='same'))
    # 防止过拟合
    model.add(tf.keras.layers.Dropout(rate=0.3))
    #第1个池化层
    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2,2)))
    #第2个卷积层
    model.add(tf.keras.layers.Conv2D(filters = 64,kernel_size=(3,3), activation='relu', padding ='same'))
    # 防止过拟合
    model.add(tf.keras.layers.Dropout(rate=0.3))#第2个池化层
    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2,2)))# 平坦层

    #第3个卷积层
    model.add(tf.keras.layers.Conv2D(filters = 128,kernel_size=(3,3), activation='relu', padding ='same'))
    # 防止过拟合
    model.add(tf.keras.layers.Dropout(rate=0.3))#第3个池化层
    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2,2)))# 平坦层

    model.add(tf.keras.layers.Flatten())# 添加输出层
    model.add(tf.keras.layers.Dense(10,activation='softmax'))
    return model

第六步:开始加载模型

执行模型函数

model = creatAlexNet()

输出摘要

model.summary()

摘要结果如下: 

 超参数定义及模型训练

'''
model.compile(optimizer =,loss=,metrics=)
optimizer:指定优化器,可以传入字符串标识符(如'rmsprop'、'adam'等),也可以传入Optimizer类的实例。
loss:指定损失函数,可以传入字符串标识符(如'mse'、'categorical_crossentropy'等),也可以传入自定义的损失函数。
metrics:指定评估指标,可以传入字符串标识符(如'accuracy'、'mae'等),也可以传入自定义的评估函数或函数列表
'''

'''
model.fit(x=,y=,batch_size=,epochs=,verbose=,validation_data=,validation_split=,shuffle=,callbacks=)
x:训练数据,通常为一个形状为(样本数, 特征数)的numpy数组,也可以是一个包含多个numpy数组的列表。

y:标签,也是一个numpy数组或列表,长度应与x的第一维相同。

batch_size:批量大小,表示每次迭代训练的样本数,通常选择2的幂次方,比如32、64、128等。

epochs:训练轮数,一个轮数表示使用所有训练数据进行了一次前向传播和反向传播,通常需要根据实际情况调整。

verbose:输出详细信息,0表示不输出,1表示输出进度条,2表示每个epoch输出一次。

validation_data:验证数据,通常为一个形状与x相同的numpy数组,也可以是一个包含多个numpy数组的列表。

validation_split:切分验证集,将训练数据的一部分用作验证数据,取值范围在0到1之间,表示将训练数据的一部分划分为验证数据的比例。

shuffle:是否打乱训练数据,True表示每个epoch之前打乱数据,False表示不打乱数据。

callbacks:回调函数,用于在训练过程中定期保存模型、调整学习率等操作,
常用的回调函数包括ModelCheckpoint、EarlyStopping、ReduceLROnPlateau等。
'''

# 设置训练参数
train_epochs=10#训练轮数
batch_size=100#单次训练样本数(批次大小)

# 定义训练模式
model.compile(optimizer ='adam',#优化器
loss='sparse_categorical_crossentropy',#损失函数
              metrics=['accuracy'])#评估模型的方式
#训练模型
train_history = model.fit(x_train,y_train,validation_split = 0.2, epochs = train_epochs, 
                          batch_size = batch_size)

训练过程如下:

第七步:训练的损失率和成功率的可视化图

# 定义训练过程可视化函数
def visu_train_history(train_history,train_metric,validation_metric):
    plt.plot(train_history.history[train_metric])
    plt.plot(train_history.history[validation_metric])
    plt.title('Train History')
    plt.ylabel(train_metric)
    plt.xlabel('epoch')
    plt.legend(['train','validation'],loc='upper left')
    plt.show()

 损失率可视化

visu_train_history(train_history,'loss','val_loss')

 

成功率可视化 

visu_train_history(train_history,'accuracy','val_accuracy')

 

第八步:模型测试及评估

用测试集评估模型

# 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print('Test accuracy:', test_acc)

 模型测试,可视化测试

#model test
preds = model.predict(x_test)

可视化函数

# 定义显示图像数据及其对应标签的函数
# 图像列表
def plot_images_labels_prediction(images,# 标签列表
                                  labels,
                                  preds,#预测值列表
                                  index,#从第index个开始显示
                                  num = 5):  # 缺省一次显示5幅
    fig=plt.gcf()#获取当前图表,Get Current Figure 
    fig.set_size_inches(12,6)#1英寸等于2.54cm 
    if num > 10:#最多显示10个子图
        num = 10
    for i in range(0, num):
        ax = plt.subplot(2,5,i+1)#获取当前要处理的子图
        plt.tight_layout()
        ax.imshow(images[index])
        title=str(i)+','+label_dict[labels[index][0]]#构建该图上要显示的title信息
        if len(preds)>0:
            title +='=>' + label_dict[np.argmax(preds[index])]
        ax.set_title(title,fontsize=10)#显示图上的title信息
        index += 1 
    plt.show()

执行可视化函数

plot_images_labels_prediction(x_test,y_test, preds,15,30)

 结果如下:

第九步:模型保存及模型使用,测试外部图片

保存模型

# 保存模型
model_filename ='models/cifarCNNModel.h5'
model.save(model_filename)

加载模型,测试模型

方法一:使用TensorFlow内部模块加载图片,将dog.jpg路径换成你的图片路径

# 加载模型
loaded_model = tf.keras.models.load_model('models/cifarCNNModel.h5')

type = ("airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")
label_dict={0:"airplane",1:"automobile",2:"bird",3:"cat",4:"deer",5:"dog", 6:"frog", 7:"horse", 8:"ship", 9:"truck"}

# 加载外来图片
img = tf.keras.preprocessing.image.load_img(
    'dog.jpg', target_size=(32, 32)
)

# 转化为numpy数组
img_array = tf.keras.preprocessing.image.img_to_array(img)

# 归一化数据
img_array = img_array / 255.0

# 维度扩展
img_array = np.expand_dims(img_array, axis=0)

# 预测类别
predictions = loaded_model.predict(img_array)
pre_label = np.argmax(predictions)
plt.title("type:{}, pre_label:{}".format(label_dict[pre_label],pre_label))
plt.imshow(img, cmap=plt.get_cmap('gray'))

 结果如下,预测结果是正确的,我这里在浏览器下载的确实是一张狗的图片 

方法二:使用PIL的库加载图片进行预测 

from PIL import Image
import numpy as np

img = Image.open('./cat.jpg')
img = img.resize((32, 32))
img_arr = np.array(img) / 255.0
img_arr = img_arr.reshape(1, 32, 32, 3)
pred = model.predict(img_arr)
class_idx = np.argmax(pred)
plt.title("type:{}, pre_label:{}".format(label_dict[class_idx],class_idx))
plt.imshow(img, cmap=plt.get_cmap('gray'))

结果如下,也是正确的,我这张图片确实是一张猫的图片

 

 方法三:从网络上加载图片进行预测,将下面的网址换成你想要预测的图片网址

# 加载模型
loaded_model = tf.keras.models.load_model('models/cifarCNNModel.h5')
# 使用模型预测浏览器上的一张图片
type = ("airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")
label_dict={0:"airplane",1:"automobile",2:"bird",3:"cat",4:"deer",5:"dog", 6:"frog", 7:"horse", 8:"ship", 9:"truck"}

url = 'https://img1.baidu.com/it/u=1284172325,1569939558&fm=253&fmt=auto&app=138&f=JPEG?w=500&h=580'
with urllib.request.urlopen(url) as url_response:
    img_array = np.asarray(bytearray(url_response.read()), dtype=np.uint8)
    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    img_array = cv2.resize(img, (32, 32))
    img_array = img_array / 255.0
    img_array = np.expand_dims(img_array, axis=0)
    
    predict_label = np.argmax(loaded_model.predict(img_array), axis=-1)[0]
    plt.imshow(img, cmap=plt.get_cmap('gray'))
    plt.title("Predict: {},Predict_label: {}".format(type[predict_label],predict_label))
    plt.xticks([])
    plt.yticks([])

结果如下, 这张就预测错了,明明是狗,预测成鸟(bird)去了

 

那么本篇文章CIFAR10数据集分类模型训练就到此结束,感谢大家的继续支持!

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/622814.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Vue.js】对Vue-element-admin做代理网关转发proxy配置

文章目录 环境配置配置vue.config.js演示为啥要这么做呢? 环境配置 .env.development # 开发环境 .env.production # 生产环境我们需要在两个环境变量文件中配置 VUE_APP_BASE_API /dev # 这里配置全局的API前置标识 开发环境我使用的/dev 生产环境用的是/prod V…

Word控件Spire.Doc 【其他】教程(8):在 Word 中嵌入多媒体文件

Spire.Doc for .NET是一款专门对 Word 文档进行操作的 .NET 类库。在于帮助开发人员无需安装 Microsoft Word情况下,轻松快捷高效地创建、编辑、转换和打印 Microsoft Word 文档。拥有近10年专业开发经验Spire系列办公文档开发工具,专注于创建、编辑、转…

物联网开发中常用的几款传感器

传感器是物联网中的关键部件,在物联网开发中发挥着重要作用。目前,市场上的传感器种类繁多,它们有许多用途。有些传感器可能主要用于测量温度、压力、流量等物理量,有些则用于测量位置、距离、速度和加速度等物理量,还…

干货第一弹!多组学联合分析之代谢组FAQ

代谢组是对生物体内代谢产物全谱分析的一种研究手段,代谢产物包括核酸、蛋白质、脂类生物大分子以及其他小分子物质,目前主要是检测1000Da以下的物质。代谢组研究具有高通量的检测能力、高灵敏度和准确度、非侵入性、非破坏性、全面性、数据资源整合等特…

DIY制作隔离信号注入变压器

最近在学习模电知识,接触到了测量运放环路增益,需要使用合适的注入变压器,查找资料发现商用信号注入变压器价格昂贵,不适合个人学习使用。看到LOTO使用普通音频变压器做测试,也跟技术群友做了交流,尝试使用…

企业构建高性能Web应用的重要组件

目 录 01 出现背景 ‍‍‍‍‍‍‍ 02 PrimetonLB、PrimetonMemDB在高性能Web应用中的作用 03 与PAS的集成‍‍ 04 优势体现 05 总结 01 出现背景‍ 随着互联网的快速发展和普及,各类Web应用已成为人们日常生活的重要组成,人们对Web应用的要求从过去的…

使用QMenu和mousePressEvent制作右键弹出菜单

我需要实现一个在QTextBrowser上邮件弹出菜单的效果,如下所示: 创建QTextBrowser的子类MyTextBrowser 首先创建一个QTextBrowser的子类,MyTextBrowser,如下所示:并定义一个QMenu指针 #ifndef MYTEXTBROWSER_H #defin…

webpack打包处理字体图标、map4、map3、avi资源

一、字体图标资源的下载(阿里巴巴图标库) iconfont官网:https://www.iconfont.cn/ 这里你可以搜索你想要的字体图标,或者选择官方的图标库中查找,我这里就以官方的图标库为例: 选择几个加入购物车 点…

关于libc++_shared.so 与libstdc++、libc++的链接关系

问题点1: -lstdc 与libc_shared.so的关联; 当在makefile中引入-lstdc时,其意味着调用动态库libstdc.so, Note:动态库libstdc.so 所对应的静态库是libstdc.a; Note:当前测试libstdc.so来自于Android12的./prebuilts/gcc/linux-x86/host/x8…

图数据库实践 - 如何将图数据库应用于供应链管理

导读 当前,随着全球化的加速和供应链的复杂性增加,供应链风险管理已经成为企业日常运营中不可忽视的重要方面。由于自然灾害、贸易保护、供应商更迭等因素的影响,供应链中的任何一个环节出现问题都可能导致生产中断、物流延误、成本增加&…

结构型设计模式06-桥接模式

🧑‍💻作者:猫十二懿 ❤️‍🔥账号:CSDN 、掘金 、个人博客 、Github 🎉公众号:猫十二懿 桥接模式 1、桥接模式模式介绍 桥接模式(Bridge Pattern)是一种结构型模式之一…

ssm+java+mysql在线捐赠系统

本系统实现一个在线捐赠系统,分为用户和管理员两种用户。具体功能描述如下: 后台管理员模块包括: 1. 系统用户管理:此功能为超级管理员所有,普通管理员没有此权限,实现超级管理员可以对普通管理员信息的…

如和使用matlab进行求导 ,入门级教程

文章目录 问题如图所示运行结果如图代码分析完整代码完结撒花 问题如图所示 运行结果如图 代码分析 % 定义样本数量 n 500;这行代码定义了一个变量 n,它代表样本数量。这个变量在后面的代码中会被用到。 % 将 s 和 z 取值范围分成子区间的个数 num_intervals 40…

MySQL数据库迁移到ORACLE(持续更新)

1. 使用Oracle SQL Developer 官方 SQL Developer 23.1下载 选择Windows 64-bit with JDK 11 included安装 2.下载后解压,选择exe执行启动,启动后见图 3. 创建连接 默认支持创建Oracle连接(见下图),第三方连接需导入…

企业微信自建应用 挂载网页步骤

打开企业微信网页端,并登录 企业微信 https://work.weixin.qq.com/wework_admin/frame#index 点击应用管理 再次点击 应用,划到自建版块,点击创建应用 依次添加应用信息 点击创建应用, 添加指定网页信息

【Android Studio】Flamingo版本 更新gradle插件(AGP) 7.+到8.+

步骤 build.gradle(module) android {namespace //adddefaultConfig {applicationId }}AndroidManifest.xml 取消package属性 <?xml version"1.0" encoding"utf-8"?> <manifest xmlns:android"http://schemas.android.com/apk/res/andr…

如何设置imagedraw.draw.text的字体大小

如何设置imagedraw.draw.text的字体大小 解决方法 虽然绘制框是draw.text() 但是这个函数没有提供修改的参数 解决方法 其实在字体中已经设置了大小了&#xff0c;他是按照图像调整的&#xff0c;我就直接修改了。 参考文章

QTableWidget自定义单元格

一 自定义QTableWidget 创建一个Widget项目&#xff0c;注释掉其中的ui->setupUi(this);使用自定义的布局。 #include "widget.h" #include "ui_widget.h" #include <QTableWidget> #include <QTableWidgetItem> #include <QLineEdit&…

Vue.js中的provide和inject方法是什么,有什么区别

Vue.js中的provide和inject方法 在Vue.js中&#xff0c;provide和inject是用于父组件向子组件传递数据的一种技术。通过使用provide和inject&#xff0c;我们可以在组件树中任意层次的组件之间进行数据的传递和共享&#xff0c;从而实现复杂的数据交互和状态管理的需求。本文将…

FANUC机器人MODBUS TCP通信配置方法(示教器实物演示)

FANUC机器人MODBUS TCP通信配置方法(示教器实物演示) 机器人一侧的配置: 如下图所示,示教器上找到设置—主机通讯, 如下图所示,选择第一项TCP/IP,点击详细进入配置界面, 如下图所示,设置机器人端口1#的IP地址为192.168.1.10,子网掩码:255.255.255.0 如下图所示…