本次项目是使用AlexNet实现5种花类的识别。
训练集搭建与LeNet大致代码差不多,但是也有许多新的内容和知识点。
1.导包,不必多说。
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time
2.指定设备
device函数用来指定在训练过程中所使用的设备:如果有可用的GPU,那么使用第一块GPU,如果没有就默认使用cpu。
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
3.数据预处理函数
单独定义出来,当key为“train”或为“val”时,返回数据集要使用的一系列预处理方法。
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224), # 把图片重新裁剪为224*224
transforms.RandomHorizontalFlip(), # 水平方向随机翻转
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
"val": transforms.Compose([transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)])}
4.获取数据集的路径
os.getcwd()方法获取当前文件所在的目录
os.path.join()方法将当前路径与上两级路径链接起来
image_path:获取到flower_data所在路径
data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
image_path = data_root + "/data_set/flower_data"
# train set
train_dataset = datasets.ImageFolder(root=image_path + "/train", # 获取训练集的路径
transform=data_transform["train"]) # 训练预处理
train_num = len(train_dataset) # 打印训练集有多少张照片
5.加载数据集分类文件
{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4} :数据集共分为五类
flower_list = train_dataset.class_to_idx 获取分类的名称所对应的索引值
cla_dict = dict((val, key) for key, val in flower_list.items()) 将字典中键与值的位置对换
?为什么要换位置
=>这样在预测后可以直接通过值给到我们最后的测试类别
json_str = json.dumps(cla_dict, indent=4) :将字典编码成json格式
with open('class_indices,json', 'w') as json_file:
json_file.write(json_str) :将键值对保存到json文件中,方便后续在预测时读取信息
下面是生成的json文件
# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# 把文件写入接送文件
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices,json', 'w') as json_file:
json_file.write(json_str)
6.载入测试集
代码大致与LeNet网络差不多,载入测试集的图片路径需要自己定义并进行预处理。
在使用matplotlib查看图片时,注意修改为batch_size=4,shuffle=True参数。
batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size, shuffle=True,
num_workers=0)
#
validate_dataset = datasets.ImageFolder(root=image_path + "/val",
transform=data_transform["val"])
val_num = len(validate_dataset)
validata_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size,
shuffle=False, num_workers=0)
暂时的全部代码,训练集还没有完全实现,我后续会补充上的,因为课真的是太多了。
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
"val": transforms.Compose([transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)])}
data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
image_path = data_root + "/data_set/flower_data"
# train set
train_dataset = datasets.ImageFolder(root=image_path + "/train",
transform=data_transform["train"])
train_num = len(train_dataset)
# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# 把文件写入接送文件
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices,json', 'w') as json_file:
json_file.write(json_str)
batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size, shuffle=True,
num_workers=0)
#
validate_dataset = datasets.ImageFolder(root=image_path + "/val",
transform=data_transform["val"])
val_num = len(validate_dataset)
validata_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=4,
shuffle=True, num_workers=0)
学习碎碎念:
学习的道路上总会是遇到困难和麻烦的,不要心急,不要烦躁,一步一步的解决问题,慢慢来总会好的!