PyTorch 读取图片,主要是通过 Dataset 类
,所以先简单了解一下
Dataset
类。
Dataset
类作为所有的
datasets
的基类存在,所有的
datasets
都需要继承它,类似于
C++
中的虚基
类。
这里重点看
getitem
函数,
getitem
接收一个
index
,然后返回图片数据和标签,这个
index
通常指的是一个
list
的
index
,这个
list
的每个元素就包含了图片数据的路径和标签信
息。
然而,如何制作这个
list
呢,通常的方法是将图片的路径和标签信息存储在一个
txt
中,然后从该
txt
中读取。
那么读取自己数据的基本流程就是:
1.
制作存储了图片的路径和标签信息的
txt
2.
将这些信息转化为
list
,该
list
每一个元素对应一个样本
3.
通过
getitem
函数,读取数据和标签,并返回数据和标签
因此,要让
PyTorch
能读取自己的数据集,只需要两步:
1.
制作图片数据的索引
2.
构建
Dataset
子类
1.生成记事本代码
import os
base_dir = "E:/pytorch_learning" #修改为当前Data 目录所在的绝对路径
'''
为数据集生成对应的txt文件
'''
base_dir = "E:/pytorch_learning" #修改为当前Data 目录所在的绝对路径
train_txt_path = os.path.join(base_dir, "Data", "train.txt")
train_dir = os.path.join(base_dir, "Data", "train")
valid_txt_path = os.path.join(base_dir, "Data", "valid.txt")
valid_dir = os.path.join(base_dir, "Data", "valid")
print(train_txt_path)
print(train_dir)
print(valid_txt_path)
print(valid_dir)
def gen_txt(txt_path, img_dir):
f = open(txt_path, 'w')
for root, s_dirs, _ in os.walk(img_dir, topdown=True): # 获取 train文件下各文件夹名称
for sub_dir in s_dirs:
i_dir = os.path.join(root, sub_dir) # 获取各类的文件夹 绝对路径
img_list = os.listdir(i_dir) # 获取类别文件夹下所有png图片的路径
for i in range(len(img_list)):
if not img_list[i].endswith('png'): # 若不是png文件,跳过
continue
label = img_list[i].split('_')[0]
img_path = os.path.join(i_dir, img_list[i])
line = img_path + ' ' + label + '\n'
f.write(line)
f.close()
gen_txt(train_txt_path, train_dir)
gen_txt(valid_txt_path, valid_dir)
2.效果
3.Dataset类代码
class MyDataset(Dataset):
def __init__(self, txt_path, transform=None, target_transform=None):
fh = open(txt_path, 'r')
imgs = []
for line in fh:
line = line.rstrip() #rstrip函数返回字符串副本,该副本是从字符串最右边删除了参数指定字符后的字符串,不带参数进去则是去除最右边的空格
words = line.split() #默认以空格为分隔符
imgs.append((words[0], int(words[1])))
self.imgs = imgs # 最主要就是要生成这个list, 然后DataLoader中给index,通过getitem读取图片数据
# transform 是一个 Compose 类型,里边有一个 list,list 中就会定义了各种对图像进行处理的操作,
#可以设置减均值,除标准差,随机裁剪,旋转,翻转,仿射变换等操作
#在这里我们可以知道,一张图片读取进来之后,会经过数据处理(数据增强),
#最终变成输入模型的数据。这里就有一点需要注意,PyTorch 的数据增强是将原始图片进行了处理
#并不会生成新的一份图片,而是“覆盖”原图
self.target_transform = target_transform
self.transform = transform
def __getitem__(self, index):
fn, label = self.imgs[index]
#对图片进行读取
img = Image.open(fn).convert('RGB') # 像素值 0~255,在transfrom.totensor会除以255,使像素值变成 0~1
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label
def __len__(self):
return len(self.imgs)
4.dataload
当
Mydataset
构建好,剩下的操作就交给
DataLoder
,在
DataLoder
中,会触发
Mydataset
中的
getiterm
函数读取一张图片的数据和标签,并拼接成一个
batch
返回,作为
模型真正的输入。