使用Resnet残差网络对图像进行分类
(猫十二分类,模型定义、训练、保存、预测)(一)
目录
一、项目简介
二、环境说明
1、安装库
2、导入需要的库
三、分类过程
(1)、解压数据集
(2)、相关文件夹的删除与建立
(3)、图片模式检验
(4)、P、RGBA、L模式的图片转换为RGB模式
(5)、数据可视化
一、项目简介
对十二种猫进行分类,属于CV方向经典的图像分类任务。图像分类任务作为其他图像任务的基石,可以更快上手计算机视觉,锻炼深度学习基本能力。
猫十二分类项目针对猫脸识别12种猫分类数据集,对猫猫所属分类进行预测,最终将结果输出到CSV文件中。
项目亮点在于使用ResNet50残差网络,使用残差网络对猫猫进行分类,并载入相应的预训练模型即可快速上手体验完成训练任务。
使用飞桨框架,基于特定赛题下的数据完成。数据集包含12种类的猫的图片。整个数据分为训练集与测试集。
训练集: 提供高清彩色图片以及图片所属的分类,共有2160张猫的图片,含标注文件。
测试集: 仅提供彩色图片,共有240张猫的图片,不含标注文件。
结果文件为CSV文件格式,命名为result.csv,文件内的字段按照指定格式写入。每行内容格式为:文件名,分类结果,文件名和分类结果使用英文逗号","分隔;
文件格式:WMgOhwZzacY023lCusqnBxIdibpkT5GP.jp,0
其中,前半部分为【图片路径】,后半部分为【类别编号】,数据列以逗号分隔,每一行数据都以回车符结束。结果文件总行数应为240;
此为初级入门,代码部分参照网络各种示例。地址见下文附录。本文是个人学习练习记录,代码归属原创作者。
二、环境说明
1、安装库
除了常用的数据科学包,下面代码里我们需要用到paddlex等库,提前安装一下。
!pip install paddlex
!pip install paddleslim
!pip install pyecharts
运行时长:54.291秒结束时间:2023-04-29 13:18:24
2、导入需要的库
导入需要的数据科学包,深度学习包
# 数据科学包
import random # 随机切分数据集
import numpy as np # 常用数据科学包
import os
from PIL import Image # 图像读取
import matplotlib.pyplot as plt # 代码中快速验证
import cv2 # 图像包
import pandas as pd
import shutil # 文件文档处理库
import imghdr # 检测图片类型
# 深度学习包
import paddle
import paddle.vision.transforms as T # 数据增强
from paddle.io import Dataset, DataLoader # 定义数据集
import paddlex as pdx
from paddle.regularizer import L2Decay # L2 权重衰减正则化
运行时长:5毫秒结束时间:2023-04-29 13:19:13
三、分类过程
(1)、解压数据集
# 解压数据集
!unzip -q /home/aistudio/data/data10954/cat_12_train.zip -d data/data10954/
!unzip -q /home/aistudio/data/data10954/cat_12_test.zip -d data/data10954/
运行时长:3.877秒结束时间:2023-04-29 13:19:24
我们打开数据集解压的 /home/aistudio/data/data10954/cat_12_train/ 目录看一下,发现里面都是.jpg文件,这些就是猫咪的图片,整个数据可作为训练集与测试集使用。随便打开几张图片,
素材是彩色格式的猫咪图片。
(2)、相关文件夹的删除与建立
## 相关文件夹的删除与建立
!rm -rf data/data10954/ImageNetDataset # 删除文件夹,防止多次运行时出错
for i in range(12):
cls_path = os.path.join('data/data10954/ImageNetDataset/', '%02d' % int(i)) # 拼接路径
if not os.path.exists(cls_path):
os.makedirs(cls_path) # 创建文件夹
!ls data/data10954/ImageNetDataset # 列出文件夹(linux语句)
##生成文件名和类别的一一对应关系,之后将根据类别cls将图片放入目标文件夹:data/data10954/ImageNetDataset/*/*.jpg。
train_df = pd.read_csv('data/data10954/train_list.txt', header=None, sep='\t') # 读取测试集标签
train_df.columns = ['name', 'cls'] # 返回列索引列表
train_df['name'] = train_df['name'].apply(lambda x: str(x).strip().split('/')[-1]) # 切分文件名,舍去cat_12_train/
train_df['cls'] = train_df['cls'].apply(lambda x: '%02d' % int(str(x).strip())) # 使图片标签类别变成2位数字
00 01 02 03 04 05 06 07 08 09 10 11
运行时长:1.046秒结束时间:2023-04-29 13:19:31
(3)、图片模式检验
【 图片模式主要有以下几种】:
1、RGB 模式
RGB三原色是指红绿蓝,是光的三原色,
三原色是指所有的颜色都有这三种色彩混合而成。(三原色是依据人类视觉定义的,不存在绝对的三原色)。
根据不同的亮度值,所有颜色可以用三种颜色混合得到。将每种颜色分为0~255,共(256x256x256=16777216种颜色)。常见的24bit色彩大概是1678万种,也就是常见的1600万真彩色。
其中,纯黑色(0,0,0)和纯白色(255,255,255)
RGB 模式为真色彩模式,打印需要更改为 CMYK模式, 注意数值溢出的问题。
2、HSB 模式
建立基于人类感觉颜色的方式,
HSB色彩模式:
- H(hue)表示色相;
- S(saturation)表示饱和度;
- B(brightness)表示明度。
- HSB是色相、饱和度、明度的相应英文首字母缩写。
3、HSV模式
它比 RGB 更接近人们对彩色的感知经验。直观地表达颜色的色调、鲜艳程度和明暗程度,方便进行颜色的对比。
- HSV 表达彩色图像的方式由三个部分组成:
Hue(色调、色相)
Saturation(饱和度、色彩纯净度)
Value(明度)
4、HLS 模式
HLS 和 HSV 比较类似。HLS 也有三个分量,hue(色相)、saturation(饱和度)、lightness(亮度)。
HLS 和 HSV 的区别就是最后一个分量不同,HLS 的是 light(亮度),HSV 的是 value(明度)。
HLS 中的 L 分量为亮度,亮度为100,表示白色,亮度为0,表示黑色;
HSV 中的 V 分量为明度,明度为100,表示光谱色,明度为0,表示黑色。
5、CMYK模式
CMYK是印刷四色模式,彩色印刷时采用的一种套色模式,因为不能保证纯度,所以需要黑。利用色料的三原色混色原理,加上黑色油墨,共计四种颜色混合叠加,形成所谓“全彩印刷”。
CMYK 颜色是青色、品红色、黄色和黑色的组合。CMYK (CYAN-MAGENTA-YELLOW-BLACK INK): 青色 - 品红 - 黄色 - 黑色
C代表青色,
M代表洋红色,
Y代表黄色,
K代表黑色。
6、Lab模式
Lab模式也是由三个通道组成,
- L通道是明度。
- a通道的颜色是从红色到深绿;
- b通道则是从蓝色到黄色。
两个分量的变化都是从-120到+120。
当a=0、 b=0时显示灰色,同时L=100时为白色,L=0时为黑色。
7、灰度模式,只有灰度, 所有颜色转化为灰度值,见L,I,F。
灰度图与彩色图不同,彩色图中一个像素通常用几个值同时表示,灰度图一个像素只有一个值:即亮度(也叫灰阶)。最常见的是256级灰阶,一个像素用1Byte表示,即0~255,当然像素值=0,表示这是个纯黑点,像素值=255,这是一个纯白点。高精度的灰阶图,会用更多的Byte来表示一个像素值。
8、索引模式
索引模式索引图像是一种把像素值直接作为RGB调色板下标的图像。索引图像可把像素值“直接映射”为调色板数值。索引模式和灰度模式比较类似,它的每个象素点也可以有256种颜色容量,但它可以负载彩色。索引的图像只支持一个图层,并且只有一个索引彩色通道。由于它最多只能有256种彩色,所以它所形成的文件相对其它彩色要小得多。索引模式主要用于网络上的图片传输和一些对图像象素、大小等有严格要求的地方。
9、多通道模式(Multichannel)
多通道模式,删除RGB,CMYK,Lab中某一个通道后,会转变为多通道,多通道用于处理特殊打印,它的每个通道都为256级灰度通道。
多通道模式对有特殊打印要求的图像非常有用。例如,如果图像中只使用了一两种或两三种颜色时,使用多通道模式可以减少印刷成本并保证图像颜色的正确输出。
10、8位/16位通道模式
在灰度RGB或CMYK模式下,可以使用16位通道来代替默认的8位通道。根据默认情况,8位通道中包含256个色阶,如果增到16位,每个通道的色阶数量为65536个,这样能得到更多的色彩细节。Photoshop可以识别和输入16位通道的图像,但对于这种图像限制很多,所有的滤镜都不能使用,另外16位通道模式的图像不能被印刷。
11、双色调模式(Duotone)
双色调模式采用2-4种彩色油墨来创建由双色调(2种颜色)、三色调(3种颜色)和四色调(4种颜色)混合其色阶来组成图像。在将灰度图像转换为双色调模式的过程中,可以对色调进行编辑,产生特殊的效果。而使用双色调模式最主要的用途是使用尽量少的颜色表现尽量多的颜色层次,这对于减少印刷成本是很重要的,因为在印刷时,每增加一种色调都需要更大的成本。
12、位图模式(Bitmap)
位图模式用两种颜色(黑和白)来表示图像中的像素。位图模式的图像也叫作黑白图像。因为其深度为1,也称为一位图像。由于位图模式只用黑白色来表示图像的像素,在将图像转换为位图模式时会丢失大量细节,
在宽度、高度和分辨率相同的情况下,位图模式的图像尺寸最小,约为灰度模式的1/7和RGB模式的1/22以下。
13、P(pallete)模式
P(pallete)模式:P代表palette,调色板模式,也就是图片中会包含一个调色表的列表,每一个像素位置放的只是一个index,那么这个像素要展示的颜色就是调色板中第index位置展示的颜色。把原来单像素占用24(32)个bit的RGB(A)真彩图片中的像素值,重映射到了8bit长,即0~255的数值范围内。而这套映射关系,就是属于这张图的所谓“调色板”(Pallete)。
14、RGBA模式
RGBA是代表Red(红色)Green(绿色)Blue(蓝色)和Alpha的色彩空间。
虽然它有的时候被描述为一个颜色空间,但是它其实仅仅是RGB模型的附加了额外的信息。采用的颜色是RGB,可以属于任何一种RGB颜色空间,alpha通道一般用作不透明度参数。如果一个像素的alpha通道数值为0%,那它就是完全透明的(也就是看不见的),而数值为100%则意味着一个完全不透明的像素(传统的数字图像)。在0%和100%之间的值则使得像素可以透过背景显示出来,就像透过玻璃(半透明性),这种效果是简单的二元透明性(透明或不透明)做不到的。它使数码合成变得容易。alpha通道值可以用百分比、整数或者像RGB参数那样用0到1的实数表示。
【 查看图片模式 】
#查看图片模式
img =Image.open('/home/aistudio/data/data10954/cat_12_train/03j9aZ5Gkq7vMDRnVQFwfbrHx8TEeoch.jpg')
print("图片模式:",img.mode)
plt.imshow(img)
plt.show()
RGB格式的图片:
img = Image.open('data/data10954/cat_12_train/F3VnNwb2K9tgMWLodrXl1f6PIEjYqhy8.jpg')
print("图片模式:",img.mode)
plt.imshow(img)
plt.show()
L格式的图片:
img = Image.open('data/data10954/cat_12_train/tO6cKGH8uPEayzmeZJ51Fdr2Tx3fBYSn.jpg')
print("图片模式:",img.mode)
plt.imshow(img)
plt.show()
P格式的图片:
img = Image.open('data/data10954/cat_12_train/ulFBEZNRQrxn57voHAJ4UG6Mct2sw1Cj.jpg')
print("图片模式:",img.mode)
plt.imshow(img)
plt.show()
RGBA格式的图片:
运行时长:1.032秒结束时间:2023-04-29 13:19:38
(4)、P、RGBA、L模式的图片转换为RGB模式
## P、RGBA、L模式的图片转换为RGB模式
for i in range(len(train_df)):
img_path = os.path.join('data/data10954/cat_12_train', train_df.at[i, 'name']) # i 元素在列中的位置 ,name 列名
if os.path.exists(img_path) and imghdr.what(img_path): # 检测路径文件是否存在及判断类别
img = Image.open(img_path) # 打开文件
if img.mode != 'RGB':
img = Image.open(img_path)
print(img_path)
print(img.mode)
img = img.convert('RGB') # 转换成rgb形式
img.save(img_path) # 保存
for img_path in os.listdir('data/data10954/cat_12_test'):
src = os.path.join('data/data10954/cat_12_test',img_path)
img = Image.open(src)
if img.mode != 'RGB':
print(img_path)
img = img.convert('RGB')
img.save(src)
data/data10954/cat_12_train/tO6cKGH8uPEayzmeZJ51Fdr2Tx3fBYSn.jpg P data/data10954/cat_12_train/ulFBEZNRQrxn57voHAJ4UG6Mct2sw1Cj.jpg RGBA data/data10954/cat_12_train/F3VnNwb2K9tgMWLodrXl1f6PIEjYqhy8.jpg L data/data10954/cat_12_train/YfsxcFB9D3LvkdQyiXlqnNZ4STwope2r.jpg P data/data10954/cat_12_train/6yYs4rvFLkQJlRxdhNfMOW52EAbgHejC.jpg RGBA data/data10954/cat_12_train/5nKsehtjrXCZqbAcSW13gxB8E6z2Luy7.jpg P data/data10954/cat_12_train/yGcJHV8Uuft6grFs7QWnK5CTAZvYzdDO.jpg P data/data10954/cat_12_train/YGyx4qCdOb7j8tzBuNfoFHLi6gU0SE3T.jpg RGBA data/data10954/cat_12_train/3yMZzWekKmuoGOF60ICQxldhBEc9Ra15.jpg P Qt29gPjYZwv3B6RJh5yiTWXrVImue1FH.jpg
运行时长:2.46秒结束时间:2023-04-29 13:19:48
(5)、数据可视化
【 查看同一类猫咪的特征 】
## Data Visualization
## 随机查看同一类猫咪的特征
plt.figure(1)
img_1_1 = Image.open('data/data10954/cat_12_train/spNU7J8uk6BXiAyQErHegYMzjOaFR2qV.jpg')
plt.subplot(2, 2, 1)
plt.imshow(img_1_1)
plt.subplot(2, 2, 2)
img_1_2 = Image.open('data/data10954/cat_12_train/7QZTYlspK2fqdJUwjC0HDmOFrM5W4PX9.jpg')
plt.imshow(img_1_2)
plt.subplot(2, 2, 3)
img_1_3 = Image.open('data/data10954/cat_12_train/oZin4PuwTet39xWCYhUBfvlzGyISb5DV.jpg')
plt.imshow(img_1_3)
plt.subplot(2, 2, 4)
img_1_4 = Image.open('data/data10954/cat_12_train/qbKjsR05lrFVYfLChtMGD7im36cUgAnE.jpg')
plt.imshow(img_1_4)
【 选取不同类别的猫咪进行查看 】
## 随机选取不同类别的猫咪进行查看
plt.figure(2)
img_0 = Image.open('data/data10954/cat_12_train/8GOkTtqw7E6IHZx4olYnhzvXLCiRsUfM.jpg')
plt.subplot(2, 6, 1)
plt.imshow(img_0)
img_0 = Image.open('data/data10954/cat_12_train/spNU7J8uk6BXiAyQErHegYMzjOaFR2qV.jpg')
plt.subplot(2, 6, 2)
plt.imshow(img_0)
img_0 = Image.open('data/data10954/cat_12_train/jbIdxGyNpoql3XQZrfREMiAzh7B46WOa.jpg')
plt.subplot(2, 6, 3)
plt.imshow(img_0)
img_0 = Image.open('data/data10954/cat_12_train/cCeBo4EJ9H1hbXsIS5G6Kxdzg27nwqfy.jpg')
plt.subplot(2, 6, 4)
plt.imshow(img_0)
img_0 = Image.open('data/data10954/cat_12_train/yxNcRSz4TI7FpwCVJBuea6MmGitZYUkK.jpg')
plt.subplot(2, 6, 5)
plt.imshow(img_0)
【 统计训练集各类猫的数目 】
检验项目样本是否不均匀,我们对data/cat_12_train中各类猫的图片数量进行统计并绘制条形图,
统计训练集各类猫的数目,防止样本不平衡问题。
## 统计训练集各类猫的数目,防止样本不平衡问题。
from pyecharts import options as opts
from pyecharts.charts import Bar
with open("data/data10954/train_list.txt", "r") as f:
labels = f.readlines()
labels = [int(i.split()[-1]) for i in labels]
counts = pd.Series(labels).value_counts().sort_index().to_list()
values = np.random.rand(12) * 100
names = [str(i) for i in list(range(12))]
data = list(zip(values, counts, names))
source = [list(i) for i in data]
source.insert(0, ["score", "amount", "product"])
c = (
Bar()
.add_dataset(
source=source
)
.add_yaxis(
series_name="",
y_axis=[],
encode={"x": "amount", "y": "product"},
label_opts=opts.LabelOpts(is_show=False),
)
.set_global_opts(
title_opts=opts.TitleOpts(title="Dataset normal bar example"),
xaxis_opts=opts.AxisOpts(name="amount"),
yaxis_opts=opts.AxisOpts(type_="category"),
visualmap_opts=opts.VisualMapOpts(
orient="horizontal",
pos_left="center",
min_=10,
max_=100,
range_text=["High Score", "Low Score"],
dimension=0,
range_color=["#D7DA8B", "#E15457"],
),
)
.render("./work/labels.html")
运行时长:121毫秒结束时间:2023-04-29 13:20:22
从源路径 src_path 移动至目标路径 dst_path。
## 从源路径 src_path 移动至目标路径 dst_path。
for i in range(len(train_df)):
# 源路径
src_path = os.path.join('data/data10954/cat_12_train',train_df.at[i, 'name']) # i 元素在列中的位置 ,name 列名
# 目标路径
dst_path = os.path.join(os.path.join('data/data10954/ImageNetDataset/',train_df.at[i, 'cls']),train_df.at[i, 'name'])
try:
shutil.move(src_path, dst_path) # 移动图片到目标路径
except Exception as e:
print(e) # 抛出错误信息
运行时长:95毫秒结束时间:2023-04-29 13:20:28
使用Resnet残差网络对图像进行分类(猫十二分类,模型定义、训练、保存、预测)(二)
推荐阅读:
给照片换底色(python+opencv) | ||
计算机视觉__基本图像操作(显示、读取、保存) | 直方图(颜色直方图、灰度直方图) | 直方图均衡化(调节图像亮度、对比度) |
语音识别实战(python代码)(一) | 人工智能基础篇 | 计算机视觉基础__图像特征 |
matplotlib 自带绘图样式效果展示速查(28种,全) | ||
Three.js实例详解___旋转的精灵女孩(附完整代码和资源)(一) | ||
| | |
立体多层玫瑰绘图源码__玫瑰花python 绘图源码集锦 | Python 3D可视化(一) | 让你的作品更出色——词云Word Cloud的制作方法(基于python,WordCloud,stylecloud) |
| | |
python Format()函数的用法___实例详解(一)(全,例多)___各种格式化替换,format对齐打印 | 用代码写出浪漫__合集(python、matplotlib、Matlab、java绘制爱心、玫瑰花、前端特效玫瑰、爱心) | python爱心源代码集锦(18款) |
| | |
Python中Print()函数的用法___实例详解(全,例多) | Python函数方法实例详解全集(更新中...) | 《 Python List 列表全实例详解系列(一)》__系列总目录、列表概念 |
| | |
用代码过中秋,python海龟月饼你要不要尝一口? | python练习题目录 | |
| | |
草莓熊python turtle绘图(风车版)附源代码 | 草莓熊python turtle绘图代码(玫瑰花版)附源代码 | 草莓熊python绘图(春节版,圣诞倒数雪花版)附源代码 |
| | |
巴斯光年python turtle绘图__附源代码 | 皮卡丘python turtle海龟绘图(电力球版)附源代码 | |
| | |
Node.js (v19.1.0npm 8.19.3) vue.js安装配置教程(超详细) | 色彩颜色对照表(一)(16进制、RGB、CMYK、HSV、中英文名) | 2023年4月多家权威机构____编程语言排行榜__薪酬状况 |
| | |
手机屏幕坏了____怎么把里面的资料导出(18种方法) | 【CSDN云IDE】个人使用体验和建议(含超详细操作教程)(python、webGL方向) | 查看jdk安装路径,在windows上实现多个java jdk的共存解决办法,安装java19后终端乱码的解决 |
| ||
vue3 项目搭建教程(基于create-vue,vite,Vite + Vue) | ||
| | |
2023年春节祝福第二弹——送你一只守护兔,让它温暖每一个你【html5 css3】画会动的小兔子,炫酷充电,字体特 | 别具一格,原创唯美浪漫情人节表白专辑,(复制就可用)(html5,css3,svg)表白爱心代码(4套) | SVG实例详解系列(一)(svg概述、位图和矢量图区别(图解)、SVG应用实例) |
| | |
【程序人生】卡塔尔世界杯元素python海龟绘图(附源代码),世界杯主题前端特效5个(附源码) | HTML+CSS+svg绘制精美彩色闪灯圣诞树,HTML+CSS+Js实时新年时间倒数倒计时(附源代码) | 2023春节祝福系列第一弹(上)(放飞祈福孔明灯,祝福大家身体健康)(附完整源代码及资源免费下载) |
| | |
tomcat11、tomcat10 安装配置(Windows环境)(详细图文) | Tomcat端口配置(详细) | Tomcat 启动闪退问题解决集(八大类详细) |