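数据集结构(对应下面 train.py 中使用的路径):

    courseHomework/datasets/
    ├── images/
    │   └── train/    # 图片文件
    └── labels/
        └── train/    # 与图片一一对应的 txt 标签文件,每行第一列为类别编号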
话不多说,直接上核心代码:
myDataset.py
from collections import Counter
from torch.utils.data import Dataset
import os
from PIL import Image
class MyDataset(Dataset):
    """Load a custom image dataset from paired image / label-file folders.

    Images are read from ``<image_dir>/<name>`` and label files from
    ``<label_dir>/<name>``. Both directory listings are sorted so that the
    i-th image is paired with the i-th label file; this assumes image and
    label files follow the same naming order (e.g. ``0001.jpg`` <-> ``0001.txt``).

    Args:
        image_dir: Root directory that contains the image split folders.
        label_dir: Root directory that contains the label split folders.
        name: Name of the split sub-folder (e.g. ``"train"``).
        transform: Optional callable applied to each loaded PIL image.

    Raises:
        ValueError: If the image folder and the label folder contain a
            different number of entries, since index-based pairing would
            then mismatch images and labels.
    """

    def __init__(self, image_dir: str, label_dir: str, name: str, transform=None):
        self.img_dir = os.path.join(image_dir, name)
        self.label_dir = os.path.join(label_dir, name)
        self.name = name
        # os.listdir() returns entries in an arbitrary, filesystem-dependent
        # order; sort both listings so image_path[i] and label_path[i]
        # refer to the same sample.
        self.image_path = sorted(os.listdir(self.img_dir))
        self.label_path = sorted(os.listdir(self.label_dir))
        if len(self.image_path) != len(self.label_path):
            raise ValueError(
                f"{self.img_dir} has {len(self.image_path)} entries but "
                f"{self.label_dir} has {len(self.label_path)}; images and "
                "label files must be paired one-to-one"
            )
        self.transform = transform

    def __getitem__(self, index: int) -> tuple:
        """Return ``(image, label)`` for the sample at ``index``.

        Args:
            index: Position of the sample in the (sorted) file listings.

        Returns:
            A tuple ``(image, label)``: ``image`` is the PIL image, or the
            result of ``self.transform`` when a transform was given; ``label``
            is the int returned by :meth:`parseTxt` for the matching txt file.
        """
        image_file = os.path.join(self.img_dir, self.image_path[index])
        # Image.open() is lazy and keeps the file handle open until the pixel
        # data is read; copy the image inside ``with`` so the handle is
        # released immediately, even when no transform forces a load.
        with Image.open(image_file) as img:
            image = img.copy()
        # To force 3-channel input (e.g. when grayscale/palette images are
        # mixed with RGB ones), uncomment the lines below. The original author
        # notes the converted image may look quite different.
        # if image.mode != 'RGB':
        #     image = image.convert('RGB')
        label_file = os.path.join(self.label_dir, self.label_path[index])
        label = self.parseTxt(label_file)
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def parseTxt(self, label: str) -> int:
        """Parse a label txt file into a single integer class label.

        A file may list several labels (one per line). The first
        whitespace-separated token of each non-blank line is parsed as an int,
        and the most frequent value is returned; on a tie, the value that
        appears first in the file wins (``Counter.most_common`` ordering).
        If your files contain exactly one label, this simply returns it.

        Args:
            label: Path to the txt label file.

        Returns:
            The most common integer in the first column.

        Raises:
            ValueError: If the file contains no non-blank lines, or if a
                first-column token is not an integer.
        """
        with open(label, 'r') as f:
            # Skip blank lines (e.g. a trailing empty line), which would
            # otherwise fail on ``split()[0]``.
            first_column = [int(line.split()[0]) for line in f if line.strip()]
        if not first_column:
            raise ValueError(f"label file {label!r} contains no labels")
        return Counter(first_column).most_common(1)[0][0]

    def __len__(self) -> int:
        """Return the number of samples, i.e. the number of entries in the image folder."""
        return len(self.image_path)
demo
train.py
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
# Import the dataset class. NOTE(review): the class above is presented as
# myDataset.py, but it is imported from a module named ``dataset`` here;
# change the module name to match where you actually saved the class.
from dataset import MyDataset
import os


def main():
    """Build the training split and preview the first image of the first batch."""
    root = os.path.join(os.getcwd(), 'courseHomework', 'datasets')
    transform = transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # transforms.Normalize((0.5,), (0.5,)),
    ])
    # Build sub-paths with os.path.join, the same way ``root`` is built,
    # rather than concatenating '/'-separated strings.
    train_dataset = MyDataset(
        os.path.join(root, 'images'),
        os.path.join(root, 'labels'),
        'train',
        transform,
    )
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=False)
    for step, data in enumerate(train_loader):
        imgs, labels = data
        print(imgs[0].shape)
        transforms.ToPILImage()(imgs[0]).show()
        # Only the first batch is needed for this preview.
        break


# The guard keeps the demo from running on import; it is also required when
# DataLoader uses worker processes (num_workers > 0) on spawn-based platforms.
if __name__ == '__main__':
    main()
如果你的数据集目录结构和我的不一样,可以按实际情况自行调整路径和标签读取逻辑。