训练和测试代码如下(完整代码来自《CNN从搭建到部署实战》):
train.py
import torch
import torchvision
import time
import argparse
import importlib
from loss import FocalLoss
def parse_args(argv=None):
    """Parse the training command-line options.

    :param argv: argument list to parse; ``None`` (the default) parses
        ``sys.argv[1:]`` exactly as before, a list allows programmatic use.
    :return: namespace with ``batch_size`` (int), ``num_epochs`` (int) and
        ``model`` (str, name of a module under ``models/``).
    """
    parser = argparse.ArgumentParser('training')
    parser.add_argument('--batch_size', default=128, type=int, help='batch size in training')
    parser.add_argument('--num_epochs', default=5, type=int, help='number of epoch in training')
    # Help text previously claimed the default was mlp; the real default is lenet.
    parser.add_argument('--model', default='lenet', help='model name [default: lenet]')
    return parser.parse_args(argv)
if __name__ == '__main__':
args = parse_args()
batch_size = args.batch_size
num_epochs = args.num_epochs
model = importlib.import_module('models.'+args.model)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = model.net.to(device)
loss = torch.nn.CrossEntropyLoss()
if args.model == 'mlp':
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
else:
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
train_path = r'./Datasets/mnist_png/training'
test_path = r'./Datasets/mnist_png/testing'
transform_list = [torchvision.transforms.Grayscale(num_output_channels=1), torchvision.transforms.ToTensor()]
if args.model == 'alexnet' or args.model == 'vgg':
transform_list.append(torchvision.transforms.Resize(size=224))
if args.model == 'googlenet' or args.model == 'resnet':
transform_list.append(torchvision.transforms.Resize(size=96))
transform = torchvision.transforms.Compose(transform_list)
train_dataset = torchvision.datasets.ImageFolder(train_path, transform=transform)
test_dataset = torchvision.datasets.ImageFolder(test_path, transform=transform)
train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
for epoch in range(num_epochs):
train_l, train_acc, test_acc, m, n, batch_count, start = 0.0, 0.0, 0.0, 0, 0, 0, time.time()
for X, y in train_iter:
X, y = X.to(device), y.to(device)
y_hat = net(X)
l = loss(y_hat, y)
optimizer.zero_grad()
l.backward()
optimizer.step()
train_l += l.cpu().item()
train_acc += (y_hat.argmax(dim=1) == y).sum().cpu().item()
m += y.shape[0]
batch_count += 1
with torch.no_grad():
for X, y in test_iter:
net.eval() # 评估模式
test_acc += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
net.train() # 改回训练模式
n += y.shape[0]
print('epoch %d, loss %.6f, train acc %.3f, test acc %.3f, time %.1fs'% (epoch, train_l / batch_count, train_acc / m, test_acc / n, time.time() - start))
torch.save(net, args.model+".pth")
test.py
import cv2
import torch
import argparse
import importlib
from pathlib import Path
import torchvision.transforms.functional
def parse_args(argv=None):
    """Parse the testing command-line options.

    :param argv: argument list to parse; ``None`` (the default) parses
        ``sys.argv[1:]`` exactly as before, a list allows programmatic use.
    :return: namespace with ``model`` (str): the checkpoint ``<model>.pth``
        is loaded and ``models.<model>`` is imported.
    """
    parser = argparse.ArgumentParser('testing')
    # Help text previously claimed the default was mlp; the real default is lenet.
    parser.add_argument('--model', default='lenet', help='model name [default: lenet]')
    return parser.parse_args(argv)
if __name__ == '__main__':
args = parse_args()
model = importlib.import_module('models.' + args.model)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = model.net.to(device)
net = torch.load(args.model+'.pth')
net.eval()
with torch.no_grad():
imgs_path = Path(r"./Datasets/mnist_png/testing/6/").glob("*")
acc = 0
count = 0
for img_path in imgs_path:
img = cv2.imread(str(img_path), 0)
if args.model == 'alexnet' or args.model == 'vgg':
img = cv2.resize(img, (224,224))
if args.model == 'googlenet' or args.model == 'resnet':
img = cv2.resize(img, (96,96))
img_tensor = torchvision.transforms.functional.to_tensor(img)
img_tensor = torch.unsqueeze(img_tensor, 0)
#print(net(img_tensor.to(device)).argmax(dim=1).item())
if(net(img_tensor.to(device)).argmax(dim=1).item()==6):
acc += 1
count+=1
print(acc/count)
数据集为MNIST手写数字数据集,其中训练集中数字0~9的数量分别为:0(5923张),1(6472张),2(5985张),3(6131张),4(5842张),5(5421张),6(5918张),7(6265张),8(5851张),9(5949张);测试集中数字0~9的数量分别为:0(980张),1(1135张),2(1032张),3(1010张),4(982张),5(892张),6(958张),7(1028张),8(974张),9(1009张)。可见各类别的数量基本平衡。注意:训练日志中的 test acc 是全部10个类别上的整体准确率,而 test.py 只统计测试集中数字6被正确识别的比例(即类别6的召回率),因为后面我们要改变训练集中数字6的数量来进行对比。为了节省时间,仅训练5个epoch。
训练结果:
epoch 0, loss 1.443379, train acc 0.529, test acc 0.877, time 23.4s
epoch 1, loss 0.314123, train acc 0.913, test acc 0.939, time 22.1s
epoch 2, loss 0.174050, train acc 0.949, test acc 0.960, time 21.9s
epoch 3, loss 0.122714, train acc 0.963, test acc 0.971, time 21.8s
epoch 4, loss 0.096798, train acc 0.971, test acc 0.975, time 21.8s
测试结果:
0.9780793319415448
现在将训练集中数字6的数量减少到59张(原来的1/100),来模拟某个类别的数据不平衡的情况。
训练结果:
epoch 0, loss 2.200247, train acc 0.131, test acc 0.373, time 20.8s
epoch 1, loss 0.579792, train acc 0.840, test acc 0.855, time 20.5s
epoch 2, loss 0.177890, train acc 0.950, test acc 0.872, time 20.3s
epoch 3, loss 0.128251, train acc 0.963, test acc 0.880, time 20.5s
epoch 4, loss 0.103937, train acc 0.969, test acc 0.888, time 20.7s
测试结果:
0.04801670146137787
可以看到,训练准确率几乎没有变化(0.971→0.969),整体测试准确率下降了约9个百分点(0.975→0.888),而数字6的测试准确率直接下降了约93个百分点(0.978→0.048),惨不忍睹。
引入FocalLoss模块(保存为loss.py,对应train.py中的 from loss import FocalLoss;参考https://github.com/QunBB/DeepLearning/blob/main/trick/unbalance/loss_pt.py):
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Union
class FocalLoss(nn.Module):
    """Focal loss (Lin et al., "Focal Loss for Dense Object Detection").

    Per sample: ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``, where ``p_t`` is
    the predicted probability of the true class and ``alpha_t`` its weight;
    ``forward`` returns the batch mean. With ``gamma = 0`` and all-ones
    ``alpha`` this reduces to ordinary cross entropy.
    """

    def __init__(self, alpha: Union[List[float], float], gamma: Optional[float] = 2, with_logits: Optional[bool] = True):
        """
        :param alpha: class weights. Multi-class: one weight per class, indexed
            by the target label, or a single number applied to every class.
            Binary: a single number is the positive-class weight (negatives get
            ``1 - alpha``); a 2-element list is ``[negative, positive]``.
        :param gamma: focusing exponent; larger values down-weight
            well-classified samples more strongly.
        :param with_logits: True if ``input`` holds raw logits (softmax/sigmoid
            is applied inside the loss); False if ``input`` already holds
            probabilities.
        """
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        # as_tensor also handles ints: the old `isinstance(alpha, float)` check
        # sent e.g. alpha=1 to torch.FloatTensor(1), an *uninitialised* tensor.
        self.alpha = torch.as_tensor(alpha, dtype=torch.float32).flatten()
        self.smooth = 1e-8  # keeps log() finite when probabilities are passed in
        self.with_logits = with_logits

    def _binary_class(self, input, target):
        """Per-sample binary focal loss; ``input`` and ``target`` are flattened to 1-D."""
        input = input.reshape(-1)
        target = target.reshape(-1).to(input.dtype)
        if self.with_logits:
            # Stable log-probabilities; log(1 - sigmoid(x)) == logsigmoid(-x).
            log_p = F.logsigmoid(input)
            log_1mp = F.logsigmoid(-input)
        else:
            # Out-of-place on purpose: the old `prob += smooth` mutated the
            # caller's tensor and broke autograd through sigmoid's output.
            log_p = torch.log(input + self.smooth)
            log_1mp = torch.log(1.0 - input + self.smooth)
        # log p_t: log-probability of the true label (the old code ignored
        # `target` and scored every sample as positive).
        logpt = target * log_p + (1.0 - target) * log_1mp
        pt = torch.exp(logpt)
        alpha = self.alpha.to(input.device)
        if alpha.numel() == 1:
            alpha_t = alpha * target + (1.0 - alpha) * (1.0 - target)
        else:
            alpha_t = alpha.gather(0, target.long())
        # clamp: smoothing can push p_t slightly above 1 (NaN for fractional gamma).
        return -alpha_t * torch.pow(torch.clamp(1.0 - pt, min=0.0), self.gamma) * logpt

    def _multiple_class(self, input, target):
        """Per-sample multi-class focal loss for ``input`` [bs, C] and integer ``target`` [bs]."""
        target = target.reshape(-1, 1)
        if self.with_logits:
            # log_softmax avoids the underflow of log(softmax(x)).
            logpt = F.log_softmax(input, dim=1).gather(1, target).view(-1)
        else:
            logpt = torch.log(input.gather(1, target).view(-1) + self.smooth)  # avoid log(0)
        pt = torch.exp(logpt)
        alpha = self.alpha.to(input.device)
        # A single weight broadcasts to every sample; otherwise index by label.
        alpha_t = alpha if alpha.numel() == 1 else alpha.gather(0, target.view(-1))
        return -alpha_t * torch.pow(torch.clamp(1.0 - pt, min=0.0), self.gamma) * logpt

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        :param input: [bs, num_classes] for multi-class, or [bs] / [bs, 1] for binary
        :param target: [bs] class indices (multi-class) or 0/1 labels (binary)
        :return: scalar mean focal loss over the batch
        """
        if len(input.shape) > 1 and input.shape[-1] != 1:
            loss = self._multiple_class(input, target)
        else:
            loss = self._binary_class(input, target)
        return loss.mean()
并将train.py中的 loss = torch.nn.CrossEntropyLoss() 这一行修改为:
# One alpha weight per class label 0-9; digit 6 (cut to 1/100 of its samples) gets 100.
loss = FocalLoss(alpha=[1, 1, 1, 1, 1, 1, 100, 1, 1, 1], gamma=2)
其中列表中的10个数字依次是类别0~9的权重值(按类别标签索引):这里把样本量被削减为原来1/100的数字6的权重设为100,其余类别为1。
训练结果:
epoch 0, loss 2.045273, train acc 0.137, test acc 0.467, time 20.7s
epoch 1, loss 0.510476, train acc 0.810, test acc 0.907, time 21.3s
epoch 2, loss 0.148246, train acc 0.922, test acc 0.941, time 21.1s
epoch 3, loss 0.099026, train acc 0.944, test acc 0.953, time 21.2s
epoch 4, loss 0.075481, train acc 0.954, test acc 0.959, time 21.3s
测试结果:
0.9196242171189979
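对比看出,FocalLoss可以有效缓解类别不均衡问题:数字6的测试准确率从0.048回升到0.920(当然并不能完全消除,有足够平衡的高质量数据集肯定更好啦~)。需要注意,这里的提升同时来自类别权重α(数字6加权100倍)和聚焦参数γ;若想区分两者各自的贡献,可以再用带相同类别权重的 CrossEntropyLoss(weight=...) 做一组对照实验。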