ubuntu卷积神经网络——图片数据集的制作以及制作好的数据集的使用

news2024/11/25 11:31:42

首先我事先准备好五分类的图片放在对应的文件夹,图片资源在我的gitee文件夹中链接如下:文件管理: 用于存各种数据https://gitee.com/xiaoxiaotai/file-management.git

 里面有imgs目录和npy目录,imgs就是存放5分类的图片的目录,里面有桂花、枫叶、五味子、银杏、竹叶5种植物,npy目录存放的是我用这些图片制作好的npy文件数据集,里面有32x32大小和64x64大小的npy文件。

接下来是数据集制作过程:

首先导入所需的库

import os
import cv2
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from mpl_toolkits.axes_grid1 import ImageGrid
%matplotlib inline
import math
from tqdm import tqdm

下面是先显示本地分类中部分图片

#先显示枫叶图片
folder_path = './datas/imgs/fengye'
# 可视化图像的个数
N = 36
# n 行 n 列
n = math.floor(np.sqrt(N))

images = []
for each_img in os.listdir(folder_path)[:N]:
    img_path = os.path.join(folder_path, each_img)
    #img_bgr = cv2.imread(img_path)
    img_bgr = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), 1) #解决路径中存在中文的问题
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    images.append(img_rgb)

fig = plt.figure(figsize=(6, 8),dpi=80)
grid = ImageGrid(fig, 111,  # 类似绘制子图 subplot(111)
                 nrows_ncols=(n, n),  # 创建 n 行 m 列的 axes 网格
                 axes_pad=0.02,  # 网格间距
                 share_all=True
                 )

# 遍历每张图像
for ax, im in zip(grid, images):
    ax.imshow(im)
    ax.axis('off')

plt.tight_layout()
plt.show()

 输出结果如下:

 下面是输出各个图片的信息包括图片宽高、图片名、所属类别,os.chdir('../')意思是将当前路径指针指向上一个目录,可以用os.getcwd()输出当前所指路径

# 指定数据集路径
dataset_path = './datas/imgs/'
os.chdir(dataset_path)
print(os.listdir())

df = pd.DataFrame()
for fruit in tqdm(os.listdir()): # 遍历每个类别    
    os.chdir(fruit)
    for file in os.listdir(): # 遍历每张图像
        try:
            img = cv2.imread(file)
            df = df.append({'类别':fruit, '文件名':file, '图像宽':img.shape[1], '图像高':img.shape[0]}, ignore_index=True)
        except:
            print(os.path.join(fruit, file), '读取错误')
    os.chdir('../')
os.chdir('../../')
df

输出结果如下:

定义标签数字,因为数据集标签一般是数字,训练才更快

# 定义5个类别的标签
labels = {
    'wuweizi': 0,
    'fengye': 1,
    'guihua': 2,
    'zhuye': 3,
    'yinxing': 4
}

# 定义训练集和测试集的比例
train_ratio = 0.8

# 定义一个空列表用于存储训练集和测试集
train_data = []
test_data = []

 数据增强,我这里是将每一张图片缩小为64x64,你也可以改成32x32或者其他大小,要注意的是,大小越大数据集制作越久,得到的数据集大小越大。

# 定义数据增强的方法
def data_augmentation(img):
    # 随机裁剪
    img = cv2.resize(img, (256, 256))
    x = random.randint(0, 256 - 64)
    y = random.randint(0, 256 - 64)
    img = img[x:x+64, y:y+64]
    

    # 随机翻转
    if random.random() < 0.5:
        img = cv2.flip(img, 1)
    
    # 随机旋转
    angle = random.randint(-10, 10)
    M = cv2.getRotationMatrix2D((32, 32), angle, 1)
    img = cv2.warpAffine(img, M, (64, 64))
    
    return img
# 定义读取图片的方法
def read_image(path):
    img = cv2.imread(path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = data_augmentation(img)
    img = img / 255.0
    return img

下面是给图片打上标签了,也就是每一张图片都给它标注属于哪一种类别(身份),这样卷积神经网络就可以在训练的时候知道类别,从而记住所属特征的标签值

# 遍历5个文件夹,读取图片并打上标签
for path, label in labels.items():
    files = os.listdir('./datas/imgs/'+path)
    random.shuffle(files)
    train_files = files[:int(len(files) * train_ratio)]
    test_files = files[int(len(files) * train_ratio):]
    for file in train_files:
        img = read_image(os.path.join('./datas/imgs/'+path, file))
        train_data.append((img, label))
    for file in test_files:
        img = read_image(os.path.join('./datas/imgs/'+path, file))
        test_data.append((img, label))
    # 工整地输出每一类别的数据个数
    print('类别:{} 训练集个数:{} 测试集数据:{}'.format(path, len(train_files), len(test_files)))

这里的输出结果:

现在可以看一下裁剪后的结果

df = pd.DataFrame()
for img,label in train_data: # 遍历每个类别    

#     img = cv2.imread(fruit)
    df = df.append({'类别':label, '文件名':file, '图像宽':img.shape[1], '图像高':img.shape[0]}, ignore_index=True)
df

 结果如下,我们可以看到大小已经变成64x64了,当然这是没有打乱顺序的,类别是从0开始到4:

接下来就是打乱顺序,这也是为了防止过拟合化

# 打乱训练集和测试集的顺序
random.shuffle(train_data)
random.shuffle(test_data)

 再次输出

df = pd.DataFrame()
for img,label in train_data: # 遍历每个类别    

#     img = cv2.imread(fruit)
    df = df.append({'类别':label, '文件名':file, '图像宽':img.shape[1], '图像高':img.shape[0]}, ignore_index=True)
df

 这一次的结果如下,类别顺序已经被打乱:

下面是保存训练集和测试集的数据集和标签

# 将训练集和测试集的图片和标签分别存储在numpy数组中
train_imgs = np.array([data[0] for data in train_data])
train_labels = np.array([data[1] for data in train_data])
test_imgs = np.array([data[0] for data in test_data])
test_labels = np.array([data[1] for data in test_data])

# 保存训练集和测试集
np.save('./datas/npy/32px/train_imgs_64.npy', train_imgs)
np.save('./datas/npy/32px/train_labels_64.npy', train_labels)
np.save('./datas/npy/32px/test_imgs_64.npy', test_imgs)
np.save('./datas/npy/32px/test_labels_64.npy', test_labels)

上面的数据集已经做好了,那么接下来就到模型的训练了,模型的训练我就不一一解释了,大家自己看代码,我使用的是anaconda中的jupyter工具写代码

#导库
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
import urllib
import cv2

# 加载上面制作的数据集
train_imgs = np.load('./datas/npy/64px/train_imgs_64.npy')
train_labels = np.load('./datas/npy/64px/train_labels_64.npy')
test_imgs = np.load('./datas/npy/64px/test_imgs_64.npy')
test_labels = np.load('./datas/npy/64px/test_labels_64.npy')

#可以看看输出纬度
train_imgs.shape

#模型构建,这里我就构建一个简单模型
def creatAlexNet():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(64, kernel_size=(3, 3), strides=(1, 1), activation='relu', input_shape=(64, 64, 3)),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1)),
        tf.keras.layers.Conv2D(128, kernel_size=(3, 3), strides=(1, 1), activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(5, activation='softmax')
    ])
    return model

#加载模型
model = creatAlexNet()

#显示摘要
model.summary()


# 定义超参数
learning_rate = 0.001 #study 
batch_size = 100 #单次训练样本数(批次大小)
epochs = 20 #训练轮数

# 定义训练模式
model.compile(optimizer ='adam',#优化器
loss='sparse_categorical_crossentropy',#损失函数
              metrics=['accuracy'])#评估模型的方式

# 加载数据集并训练模型
history = model.fit(train_imgs, train_labels, batch_size=batch_size, epochs=epochs, 
                    validation_split = 0.2)

# 评估模型
test_loss, test_acc = model.evaluate(test_imgs, test_labels, verbose=2)
print('Test accuracy:', test_acc)

#模型测试
preds = model.predict(test_imgs)
np.argmax(preds[20])

# 可视化测试
# 定义显示图像数据及其对应标签的函数
# 图像列表
label_dict={0:"wuweizi",1:"fengye",2:"guihua",3:"zhuye",4:"yinxing"}
def plot_images_labels_prediction(images,# 标签列表
                                  labels,
                                  preds,#预测值列表
                                  index,#从第index个开始显示
                                  num = 5):  # 缺省一次显示5幅
    fig=plt.gcf()#获取当前图表,Get Current Figure 
    fig.set_size_inches(12,6)#1英寸等于2.54cm 
    if num > 10:#最多显示10个子图
        num = 10
    for i in range(0, num):
        ax = plt.subplot(2,5,i+1)#获取当前要处理的子图
        plt.tight_layout()
        ax.imshow(images[index])
        title=str(i)+','+label_dict[labels[index]]#构建该图上要显示的title信息
        if len(preds)>0:
            title +='=>' + label_dict[np.argmax(preds[index])]
        ax.set_title(title,fontsize=10)#显示图上的title信息
        index += 1 
    plt.show()

plot_images_labels_prediction(test_imgs,test_labels, preds,10,30)

# 然后保存模型
model_filename ='models/plant_model.h5'
model.save(model_filename)

# 这里是从本地加载图片对模型进行测试
from PIL import Image
import numpy as np

loaded_model = tf.keras.models.load_model('models/plant_model.h5')
label_dict={0:"wuweizi",1:"fengye",2:"guihua",3:"zhuye",4:"yinxing"}

img = Image.open('./fengye.jpeg')
img = img.resize((64, 64))
img_arr = np.array(img) / 255.0
img_arr = img_arr.reshape(1, 64, 64, 3)
pred = model.predict(img_arr)
class_idx = np.argmax(pred)
plt.title("type:{}, pre_label:{}".format(label_dict[class_idx],class_idx))
plt.imshow(img, cmap=plt.get_cmap('gray'))

# 加载模型
loaded_model = tf.keras.models.load_model('models/plant_model.h5')
# 使用模型预测浏览器上的一张图片
label_dict={0:"wuweizi",1:"fengye",2:"guihua",3:"zhuye",4:"yinxing"}

# 这里是从浏览器的网址中加载图片进行识别
url = 'https://newbbs-fd.zol-img.com.cn/t_s1200x5000/g5/M00/05/08/ChMkJ1wFsOGIcMt4AAGFQDPiUhEAAtkTQCj_EoAAYVY306.jpg'
with urllib.request.urlopen(url) as url_response:
    img_array = np.asarray(bytearray(url_response.read()), dtype=np.uint8)
    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    img_array = cv2.resize(img, (64, 64))
    img_array = img_array / 255.0
    img_array = np.expand_dims(img_array, axis=0)
    
    predict_label = np.argmax(loaded_model.predict(img_array), axis=-1)[0]
    plt.imshow(img, cmap=plt.get_cmap('gray'))
    plt.title("Predict: {},Predict_label: {}".format(label_dict[predict_label],predict_label))
    plt.xticks([])
    plt.yticks([])

本次文章就到这里,感谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/519731.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Lesson14 高级IO

前言 IO 等待 数据拷贝,比如read/recv,write/send只要在单位事件里,让等的比重减低,IO的效率就越高 五种IO模型 钓鱼小案例 阻塞式 阻塞式: 张三拿着一根鱼竿,一直在岸边钓鱼,期间一直盯着鱼竿,等待鱼上钩 非阻塞式轮询式 非阻塞式轮询式: 李四拿着一根鱼竿,在岸边钓鱼,期…

Weblogic RCE合集

文章目录 CVE-2023-21839(T3/IIOP JNDI注入)前言漏洞简单分析漏洞复现防护措施 CVE-2020-2551(RMI-IIOP RCE)漏洞简单分析漏洞复现防护措施 CVE-2017-3506(wls-wsat组件XMLDecoder反序列化漏洞)漏洞简单分析漏洞复现防护措施 CVE-2020-14882&CVE-2020-14883漏洞简单分析 CV…

2023.05.11 c高级 day3

编写一个名为myfirstshell.sh的脚本&#xff0c;它包括以下内容。 包含一段注释&#xff0c;列出您的姓名、脚本的名称和编写这个脚本的目的和当前用户说“hello 用户名”显示您的机器名 hostname显示上一级目录中的所有文件的列表显示变量PATH和HOME的值显示磁盘使用情况用id命…

算法修炼之练气篇——练气十五层

博主&#xff1a;命运之光 专栏&#xff1a;算法修炼之练气篇 前言&#xff1a;每天练习五道题&#xff0c;炼气篇大概会练习200道题左右&#xff0c;题目有C语言网上的题&#xff0c;也有洛谷上面的题&#xff0c;题目简单适合新手入门。&#xff08;代码都是命运之光自己写的…

来领略一下带头双向循环链表的风采吧

&#x1f349; 博客主页&#xff1a;阿博历练记 &#x1f4d6;文章专栏&#xff1a;数据结构与算法 &#x1f68d;代码仓库&#xff1a;阿博编程日记 &#x1f339;欢迎关注&#xff1a;欢迎友友们点赞收藏关注哦 文章目录 &#x1f344;前言&#x1f37c;双向循环链表&#x1…

Qt使用星空图作为窗口背景,点击键盘的WASD控制小飞机在上面移动。

事件函数的使用依托于Qt的事件机制&#xff0c;一个来自于外部事件的传递机制模型如下所示 信号槽虽然好用&#xff0c;但是无法包含所有的情况&#xff0c;事件函数可以起到对信号槽无法覆盖的一些时机进行补充&#xff0c;事件函数的使用无需连接。 常用的事件函数如下所示。…

设计模式5—抽象工厂模式

5.抽象工厂模式 概念 抽象工厂模式&#xff1a;提供一个创建一系列相关或相互依赖对象的接口&#xff0c;而无须指定他们具体的类。抽象工厂又称为Kit模式&#xff0c;属于对象创建型模式。 抽象工厂可以将统一产品族的单独工厂封装起来&#xff0c;在正常使用中&#xff0c…

计算机网络笔记——网络层、传输层、应用层(方老师408课程)(持续更新)

文章目录 前言网络层网络层提供的两种服务网际协议——IP虚拟互联网络IP数据报格式逐一理解整体理解IP数据报分片与长度精算 IP地址IP地址概述分类的IP地址——ABCDE分类IP的子网划分不分类的IP地址——CIDRIP地址总结 IP分组的转发网际控制报文协议——ICMP下一代网络协议——…

我用 ChatGPT 干的 18 件事!【人工智能中文站创始人:mydear麦田访谈】

新建了一个网站 https://ai.weoknow.com/ 每天给大家更新可用的国内可用chatGPT 你确定你可以使用ChatGPT吗&#xff1f; 今天我整理了18种ChatGPT的使用方法&#xff0c;让大家看看你可以使用哪些。 1.语法修正 2.文本翻译 3.语言转换 4.代码解释 5.修复代码错误 6.作为百科…

初识HTML的基础知识点!!!

初识HTML&#xff01;&#xff01;&#xff01; 一、系统构架 1.B/S构架 &#xff08;1&#xff09;B/S构架&#xff08;Browser / Server) 就是&#xff08;浏览器/服务器的交互形式&#xff09; Browser支持HTML、CSS、JavaScript &#xff08;2&#xff09;优缺点 优点…

UI--基本组件

目录 1. Designer 设计师 2. Layout 布局 3. 基本组件 3.1 QWidget 3.2 ui指针 3.3 QLabel 标签&#xff08;掌握&#xff09; 示例代码&#xff1a; dialog.h dialog.cpp 3.4 QAbstractButton 按钮类&#xff08;掌握&#xff09; 示例代码&#xff1a; dialog.ui dialog.h di…

【MyBaits】SpringBoot整合MyBatis之动态SQL

目录 一、背景 二、if标签 三、trim标签 四、where标签 五、set标签 六、foreach标签 一、背景 如果我们要执行的SQL语句中不确定有哪些参数&#xff0c;此时我们如果使用传统的就必须列举所有的可能通过判断分支来解决这种问题&#xff0c;显示这是十分繁琐的。在Spring…

linux查看服务端口号、查看端口(netstat、lsof)以及PID对应服务

linux查看服务端口号、查看端口&#xff08;netstat、lsof&#xff09; netstat - atulnp会显示所有端口和所有对应的程序&#xff0c;用grep管道可以过滤出想要的字段 -a &#xff1a;all&#xff0c;表示列出所有的连接&#xff0c;服务监听&#xff0c;Socket资料 -t &…

说服审稿人,只需牢记这 8 大返修套路!

本文作者&#xff1a;雁门飞雪 如果说科研是一场修炼&#xff0c;那么学术界就是江湖&#xff0c;投稿就是作者与审稿人或编辑之间的高手博弈。 在这一轮轮的对决中&#xff0c;有时靠的是实力&#xff0c;有时靠的是技巧&#xff0c;然而只有实力和技巧双加持的作者才能长久立…

Qt--项目打包

项目打包 一款正常的软件产品应该在任何的计算机中运行&#xff0c;不需要单独安装Qt的开发环境&#xff0c;因此需要把之前的项目打包成一个安装包。 1. 设置应用图标 设置应用程序图标的操作步骤如下所示。 1. 下载一个图标图片&#xff0c;格式要求png。&#xff08;png包含…

学习Python的day.13

输入和输出 一、输入 标准输入&#xff1a;从键盘输入 input(promptNone) # prompt: 输入的提示符,可以为空 Read a string from standard input --- 译为&#xff1a;从标准输入读入一个字符串&#xff0c;输入读取的一定是字符串&#xff0c;返回值就是一个字符串 那我们…

基于知识图谱的个性化学习资源推荐系统的设计与实现(论文+源码)_kaic

摘 要 最近几年来&#xff0c;伴随着教育信息化、个性化教育和K12之类的新观念提出,一如既往的教育方法向信息化智能化的转变&#xff0c;学生群体都对这种不受时间和地点约束的学习方式有浓厚的兴趣。而现在市面上存在的推荐系统给学生推荐资料时不符合学生个人对知识获取的…

多态与虚函数

多态与虚函数 多态的引入多态与虚函数多态编译时多态运行时多态 多态的原理静态联编和动态联编 多态的引入 学过C继承的话应该都知道在继承中存在一种菱形继承&#xff0c;假设存在一个类&#xff08;person&#xff09;&#xff0c;其派生出两个子类&#xff0c;分别是studen…

Template Method模式

文章目录 &#x1f4a1;前言分类优点 &#x1f4a1;问题引入&#x1f4a1;概念&#x1f4a1;例子&#x1f4a1;总结 &#x1f4a1;前言 此文是第一篇讲解设计模式的文章&#xff0c;而笔者我又不想另起一篇来概述设计模式的分类&#xff0c;作用&#xff0c;以及优点&#xff…

MySQL笔记(四) 函数、变量、存储过程、游标、索引、存储引擎、数据库维护、指定字符集、锁机制

MySQL笔记&#xff08;四&#xff09; 文章目录 MySQL笔记&#xff08;四&#xff09;函数文本处理函数日期和时间处理函数数值处理函数类型转换函数流程控制函数自定义函数基本语法 局部变量全局变量聚集函数 aggregate functionDISTINCT 存储过程为什么要使用使用创建 删除建…