前言
参加了华为一个小比赛第四届MindCon-爱(AI)美食–10类常见美食图片分类,本来想实践机器学习课程的知识,后来发现图像分类任务基本都是用神经网络做,之前在兴趣课上学过一点神经网络但不多,通过这样一个完整的项目也算入门了。
代码仓库:https://github.com/fgmn/ResNet
任务
ResNet
这里主要结合官方pytorch代码和B站视频6.2 使用pytorch搭建ResNet并基于迁移学习训练进行理解。
模型
层数不同的网络许多子结构是相似的,因此对子结构的定义会有一些参数定义。
论文提到两种残差结构,从上面表格可以看到,左侧building block用于18,34层网络,右侧bottleneck用于50,101,152层网络。
左侧残差结构的实现如下,首先定义残差结构所使用的一系列层结构:stride=1 时输入输出矩阵大小相同,stride=2 时输出长宽均为输入的 $\frac{1}{2}$;channel 是通道数,和卷积核个数对应,如 $3\times3, 64$ 代表使用 64 个大小为 $3\times3$ 的卷积核对输入的 64 个通道进行卷积运算。
之后定义正向传播过程,它实际上确定了网络结构;bn 层位于卷积层和激活函数之间。
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut connection.

    This is the "building block" variant used by the 18- and 34-layer networks.

    Args:
        in_channel: Number of channels of the input feature map.
        out_channel: Number of channels produced by both convolutions.
        stride: Stride of the first convolution; the second always uses 1.
        downsample: Optional module applied to the input before it is added
            back to the convolution branch (the "dashed" shortcut). When
            ``None`` the input itself is used as the shortcut.
        **kwargs: Accepted and ignored, so the block can be constructed with
            the same keyword arguments ``ResNet._make_layer`` passes to every
            block type (``groups``, ``width_per_group``).
    """

    # Multiplier between ``out_channel`` and the channel count the block
    # actually outputs; ``ResNet`` reads it to size shortcuts and the fc layer.
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super().__init__()
        # Options shared by both 3x3 convolutions. The bias is disabled
        # because each convolution is immediately followed by BatchNorm.
        conv_opts = dict(kernel_size=3, padding=1, bias=False)

        self.conv1 = nn.Conv2d(in_channel, out_channel, stride=stride, **conv_opts)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channel, out_channel, stride=1, **conv_opts)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        # Shortcut branch: identity unless a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        # Main branch: conv -> BN -> ReLU -> conv -> BN.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Merge both branches, then apply the final activation.
        out += shortcut
        return self.relu(out)
同理,定义右侧残差结构,之后定义ResNet如下:
class ResNet(nn.Module):
    """ResNet assembled from a configurable residual block type.

    Args:
        block: Residual block class (e.g. ``BasicBlock`` or ``Bottleneck``).
            It must expose an ``expansion`` class attribute and accept
            ``(in_channel, channel, stride=..., downsample=..., groups=...,
            width_per_group=...)``.
        blocks_num: Sequence with the number of blocks in each of the four
            stages.
        num_classes: Output size of the final fully connected layer.
        include_top: When ``True`` the network ends with global average
            pooling and the fully connected layer; otherwise ``forward``
            returns the feature map produced by the last stage.
        groups: Forwarded unchanged to every block.
        width_per_group: Forwarded unchanged to every block.
    """

    def __init__(self, block, blocks_num, num_classes=1000, include_top=True,
                 groups=1, width_per_group=64):
        super().__init__()
        self.include_top = include_top
        # Running channel count of the tensor entering the next block;
        # updated by ``_make_layer`` as stages are built.
        self.in_channel = 64
        self.groups = groups
        self.width_per_group = width_per_group

        # Stem: 7x7 stride-2 convolution, BN, ReLU, 3x3 stride-2 max pooling.
        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four residual stages, registered as layer1 .. layer4:
        # (base channel width, stride of the first block in the stage).
        stage_plan = ((64, 1), (128, 2), (256, 2), (512, 2))
        for index, (width, stride) in enumerate(stage_plan):
            stage = self._make_layer(block, width, blocks_num[index], stride=stride)
            setattr(self, "layer{}".format(index + 1), stage)

        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Kaiming-normal initialisation for every convolution in the network.
        conv_layers = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for conv in conv_layers:
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        """Build one stage of ``block_num`` blocks and return it as a Sequential.

        The first block receives ``stride`` and, when the stride is not 1 or
        the channel count changes, a 1x1-conv + BN shortcut. The remaining
        blocks use the block's default stride and no downsample module.
        Updates ``self.in_channel`` to the stage's output channel count.
        """
        out_channel = channel * block.expansion

        shortcut = None
        if stride != 1 or self.in_channel != out_channel:
            shortcut = nn.Sequential(
                nn.Conv2d(self.in_channel, out_channel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channel))

        # Keyword arguments every block in the stage receives.
        shared = dict(groups=self.groups, width_per_group=self.width_per_group)

        # First block: may change resolution / width (dashed shortcut).
        first = block(self.in_channel, channel, downsample=shortcut, stride=stride, **shared)
        self.in_channel = out_channel

        # Remaining blocks: input and output shapes match (solid shortcut).
        rest = [block(self.in_channel, channel, **shared) for _ in range(1, block_num)]
        return nn.Sequential(first, *rest)

    def forward(self, x):
        """Run the stem and the four stages, then the optional pooling + fc head."""
        backbone = (self.conv1, self.bn1, self.relu, self.maxpool,
                    self.layer1, self.layer2, self.layer3, self.layer4)
        for stage in backbone:
            x = stage(x)

        if not self.include_top:
            return x

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)  # fully connected classifier
定义具体网络:
def resnet34(num_classes=1000, include_top=True):
    """Build a ResNet from ``BasicBlock`` with stage depths 3, 4, 6 and 3.

    Pretrained weights for transfer learning:
    https://download.pytorch.org/models/resnet34-333f7ec4.pth
    """
    stage_depths = [3, 4, 6, 3]
    model = ResNet(BasicBlock, stage_depths,
                   num_classes=num_classes,
                   include_top=include_top)
    return model
def resnet50(num_classes=1000, include_top=True):
    """Build a ResNet from ``Bottleneck`` with stage depths 3, 4, 6 and 3.

    Pretrained weights for transfer learning:
    https://download.pytorch.org/models/resnet50-19c8e357.pth
    """
    stage_depths = [3, 4, 6, 3]
    model = ResNet(Bottleneck, stage_depths,
                   num_classes=num_classes,
                   include_top=include_top)
    return model
def resnet101(num_classes=1000, include_top=True):
    """Build a ResNet from ``Bottleneck`` with stage depths 3, 4, 23 and 3.

    Pretrained weights for transfer learning:
    https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    """
    stage_depths = [3, 4, 23, 3]
    model = ResNet(Bottleneck, stage_depths,
                   num_classes=num_classes,
                   include_top=include_top)
    return model
# more ...
训练
尝试使用cuda,
# Select the training device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print("using {} device.".format(device))
定义数据的transform,进行随机裁剪,随机翻转,标准化处理等等操作,
# Image preprocessing pipelines, keyed by dataset split.
# Per-channel mean / std used to normalise the tensors of both splits.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Training: random resized crop to 224, random horizontal flip, to tensor, normalise.
train_pipeline = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# Validation: deterministic resize to 256, centre crop to 224, to tensor, normalise.
val_pipeline = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

data_transform = {"train": train_pipeline, "val": val_pipeline}
加载训练集以及验证集,施加transform,定义batch_size=16,
# ---------------------------- dataset loading ----------------------------------------
# The data root is the current working directory; images are expected under
# <cwd>/data_set/food_data/{train,val}.
data_root = os.path.abspath(os.path.join(os.getcwd()))  # get data root path
image_path = os.path.join(data_root, "data_set", "food_data")  # food data set path
assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                     transform=data_transform["train"])
train_num = len(train_dataset)

# class_to_idx maps class name -> integer label, e.g. {'<class_a>': 0, '<class_b>': 1, ...}.
# Invert it so a predicted index can be turned back into a class name.
food_list = train_dataset.class_to_idx
cla_dict = {index: name for name, index in food_list.items()}

# Persist the index -> class-name mapping as JSON.
json_str = json.dumps(cla_dict, indent=10)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 16
# os.cpu_count() returns None when the CPU count cannot be determined, which
# would make min() raise TypeError; fall back to a single CPU in that case.
cpu_count = os.cpu_count() or 1
nw = min([cpu_count, batch_size if batch_size > 1 else 0, 8])  # number of workers
print('Using {} dataloader workers every process'.format(nw))

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=nw)  # worker subprocesses used for data loading

validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=nw)

print("using {} images for training, {} images for validation.".format(train_num,
                                                                       val_num))
实例化一个34层网络,加载预训练模型,基于迁移学习方法训练,
# Instantiate the 34-layer network with its default head so its parameter
# names and shapes match the pretrained checkpoint.
net = resnet34()

# Load pretrained weights.
# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
model_weight_path = "./resnet34-pre.pth"
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)

# Deserialize the checkpoint onto the CPU, then copy the weights into the model.
pretrained_state = torch.load(model_weight_path, map_location='cpu')
net.load_state_dict(pretrained_state)
定义全连接层,交叉熵损失函数,以及Adam优化器,
# Replace the pretrained classification head with a new fully connected layer
# sized for this dataset. The class count is taken from the index -> class-name
# mapping built from the training folder instead of a hard-coded 10, so the
# head always matches the labels the loader produces.
in_channel = net.fc.in_features
num_classes = len(cla_dict)
net.fc = nn.Linear(in_channel, num_classes)
net.to(device)

# Loss function: cross entropy for multi-class classification.
loss_function = nn.CrossEntropyLoss()

# Optimizer: Adam over every parameter that requires gradients.
params = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.Adam(params, lr=0.0001)
训练并验证,保存效果最好的网络,
# Training configuration.
epochs = 20
best_acc = 0.0  # best validation accuracy seen so far
save_path = './resNet34.pth'
train_steps = len(train_loader)  # number of batches per training epoch

for epoch in range(epochs):
    # ---------------- training phase ----------------
    net.train()  # put BatchNorm / dropout layers into training mode
    running_loss = 0.0
    train_bar = tqdm(train_loader, file=sys.stdout)
    for step, (images, labels) in enumerate(train_bar):
        optimizer.zero_grad()  # clear gradients left over from the previous step

        logits = net(images.to(device))  # forward pass
        loss = loss_function(logits, labels.to(device))
        loss.backward()  # backward pass
        optimizer.step()  # update the parameters

        # Accumulate the loss and show the current batch loss on the progress bar.
        running_loss += loss.item()
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                 epochs,
                                                                 loss)

    # ---------------- validation phase ----------------
    net.eval()
    acc = 0.0  # number of correct predictions accumulated over the epoch
    with torch.no_grad():  # no gradients are needed while validating
        val_bar = tqdm(validate_loader, file=sys.stdout)
        for val_images, val_labels in val_bar:
            outputs = net(val_images.to(device))
            # Index of the largest logit along the class dimension = predicted class.
            _, predict_y = torch.max(outputs, dim=1)
            matches = torch.eq(predict_y, val_labels.to(device))
            acc += matches.sum().item()

            val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                       epochs)

    val_accurate = acc / val_num
    print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
          (epoch + 1, running_loss / train_steps, val_accurate))

    # Keep only the weights of the best-performing epoch.
    if val_accurate > best_acc:
        best_acc = val_accurate
        torch.save(net.state_dict(), save_path)

print('Finished Training')