深度学习——划分自定义数据集
以人脸表情数据集raf_db为例,初始目录如下:
需要经过处理后返回
train_images, train_label, val_images, val_label
定义 read_split_data(root: str, val_rate: float = 0.2)
方法来解决,代码如下:
# root:数据集所在路径
# val_rate:划分测试集的比例
def read_split_data(root: str, val_rate: float = 0.2):
random.seed(0) # 保证随机结果可复现
assert os.path.exists(root), "dataset root: {} does not exist.".format(root)
# 遍历文件夹,一个文件夹对应一个类别
file_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
# 排序,保证各平台顺序一致
file_class.sort()
# 生成类别名称以及对应的数字索引
class_indices = dict((k, v) for v, k in enumerate(file_class))
json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
with open('class_indices.json', 'w') as json_file:
json_file.write(json_str)
train_images = [] # 存储训练集的所有图片路径
train_label = [] # 存储训练集图片对应索引信息
val_images = [] # 存储验证集的所有图片路径
val_label = [] # 存储验证集图片对应索引信息
every_class_num = [] # 存储每个类别的样本总数
supported = [".jpg", ".JPG", ".png", ".PNG"] # 支持的文件后缀类型
# 遍历每个文件夹下的文件
for cla in file_class:
cla_path = os.path.join(root, cla)
# 遍历获取supported支持的所有文件路径
images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
if os.path.splitext(i)[-1] in supported]
# 排序,保证各平台顺序一致
images.sort()
# 获取该类别对应的索引
image_class = class_indices[cla]
# 记录该类别的样本数量
every_class_num.append(len(images))
# 按比例随机采样验证样本
val_path = random.sample(images, k=int(len(images) * val_rate))
for img_path in images:
if img_path in val_path: # 如果该路径在采样的验证集样本中则存入验证集
val_images.append(img_path)
val_label.append(image_class)
else: # 否则存入训练集
train_images.append(img_path)
train_label.append(image_class)
print("{} images were found in the dataset.".format(sum(every_class_num)))
print("{} images for training.".format(len(train_images)))
print("{} images for validation.".format(len(val_images)))
assert len(train_images) > 0, "number of training images must greater than 0."
assert len(val_images) > 0, "number of validation images must greater than 0."
return train_images, train_label, val_images, val_label
此时可通过以下代码获得训练集和测试集数据:
train_images, train_label, val_images, val_label = read_split_data(data_path)
完结撒花。