文章目录
- 网络结构
- 代码
- common_utils.py
- network.py
- provider.py
- train.py
- test.py
- visual.py
- 实验
- 训练结果
- 测试结果
- 可视化
网络结构
| 输入 | 过程 | 输出 |
|---|---|---|
| 28*28 | Flatten | 784 |
| 784 | Linear + ReLU + Dropout(0.2) | 300 |
| 300 | Linear + ReLU + Dropout(0.2) | 100 |
| 100 | Linear + LogSoftmax | 10 |
代码
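文件结构：
- common_utils.py：日志工具
- network.py：MLP 网络及其训练、评估函数
- provider.py：数据加载与可视化函数
- train.py：训练脚本
- test.py：测试脚本
- visual.py：可视化脚本
- data/：MNIST 数据目录（GetLoader 默认 path='data' 且 download=False，需提前准备好数据）
- output/：日志输出目录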
common_utils.py
用于创建日志记录器，同时将日志输出到控制台和日志文件
# common_utils.py
import logging
def create_logger(log_file=None, rank=0, log_level=logging.INFO):
    """Configure and return this module's logger (console + optional file).

    ``rank`` is presumably the distributed process rank: rank 0 logs at
    ``log_level``, every other rank only logs ERROR and above.

    Calling this function again replaces the handlers installed by a previous
    call instead of stacking new ones, so messages are never duplicated.

    Args:
        log_file: optional path of a file to also log to. Its parent
            directory is created if it does not exist yet.
        rank: process rank; non-zero ranks are raised to ERROR.
        log_level: logging level used on rank 0.

    Returns:
        The configured ``logging.Logger``.
    """
    import os  # local import keeps this module's top-level imports unchanged

    level = log_level if rank == 0 else 'ERROR'
    logger = logging.getLogger(__name__)
    logger.setLevel(level)

    # getLogger returns the same object on every call; drop (and close) the
    # handlers of an earlier call so each record is emitted only once.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter('[%(asctime)s %(filename)s %(lineno)d '
                                  '%(levelname)5s] %(message)s')

    console = logging.StreamHandler()
    console.setLevel(level)
    console.setFormatter(formatter)
    logger.addHandler(console)

    if log_file is not None:
        # FileHandler raises FileNotFoundError when the directory is missing.
        log_dir = os.path.dirname(log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        file_handler = logging.FileHandler(filename=log_file, encoding='utf-8')
        file_handler.setLevel(level)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    return logger
network.py
定义MLP网络结构，包含训练函数train_model和评估函数eval_model
# network.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import provider
class MLP(nn.Module):
    """Three-layer perceptron classifier for 28x28 single-channel images.

    Layers: Flatten -> Linear(784, 300) -> ReLU -> Dropout(0.2)
    -> Linear(300, 100) -> ReLU -> Dropout(0.2) -> Linear(100, 10) -> LogSoftmax.

    ``forward`` returns per-class log-probabilities (not raw logits), so the
    matching training criterion is ``nn.NLLLoss``.
    """

    def __init__(self):
        super(MLP, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 300)
        self.fc2 = nn.Linear(300, 100)
        self.fc3 = nn.Linear(100, 10)
        self.relu = nn.ReLU()
        # Named ``softmax`` for checkpoint/repr compatibility; it is a LogSoftmax.
        self.softmax = nn.LogSoftmax(dim=1)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Map a batch of images to (batch, 10) log-probabilities."""
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc3(x)
        x = self.softmax(x)
        return x

    @staticmethod
    def _device():
        """Return the CUDA device when available, otherwise the CPU."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _train_one_epoch(self, loader, criterion, optimizer, device):
        """Run one optimisation pass over ``loader``; return (mean batch loss, accuracy)."""
        self.train()
        running_loss = 0.0
        correct = 0
        total = 0
        n_batches = 0
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            # Forward pass.
            outputs = self(images)
            loss = criterion(outputs, labels)
            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Training accuracy uses the pre-update outputs, with dropout active.
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item()
            n_batches += 1
        if total == 0:
            raise ValueError("training loader yielded no samples")
        return running_loss / n_batches, correct / total

    def _evaluate(self, dataloader, device, criterion=None):
        """Evaluate without gradients; return (mean batch loss or None, accuracy).

        Raises:
            ValueError: if ``dataloader`` yields no samples.
        """
        self.eval()
        total = 0
        correct = 0
        loss_sum = 0.0
        n_batches = 0
        with torch.no_grad():
            for images, labels in dataloader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = self(images)
                if criterion is not None:
                    loss_sum += criterion(outputs, labels).item()
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                n_batches += 1
        if total == 0:
            raise ValueError("dataloader yielded no samples; accuracy is undefined")
        mean_loss = loss_sum / n_batches if criterion is not None else None
        return mean_loss, correct / total

    def _save_checkpoint(self, path, optimizer, epoch, best_accuracy):
        """Write model/optimizer state plus bookkeeping to ``path``."""
        torch.save({
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'best_accuracy': best_accuracy,
        }, path)

    def train_model(self, args, best_path='best_model.pth', final_path='final_model.pth'):
        """Train on the MNIST train split, evaluating on the test split after every epoch.

        Args:
            args: namespace providing ``lr``, ``batch_size``, ``epochs`` and ``logger``.
            best_path: where the checkpoint with the highest test accuracy is saved.
            final_path: where the checkpoint after the last epoch is saved.

        The weights left in memory afterwards are those of the last epoch, not
        the best one.
        """
        device = self._device()
        self.to(device)
        # forward() already outputs log-probabilities, so NLLLoss is the matching
        # criterion. CrossEntropyLoss would re-apply log_softmax (a no-op on
        # log-probabilities, so the loss values are identical).
        criterion = nn.NLLLoss()
        optimizer = optim.Adam(self.parameters(), lr=args.lr)
        # Learning-rate scheduler: multiply the LR by 0.1 every 3 epochs.
        scheduler = StepLR(optimizer, step_size=3, gamma=0.1)
        train_loader = provider.GetLoader(batch_size=args.batch_size, loadType='train')
        test_loader = provider.GetLoader(batch_size=args.batch_size, loadType='test')

        best_accuracy = 0.0
        epoch = -1  # recorded in the final checkpoint if args.epochs == 0
        for epoch in range(args.epochs):
            train_loss, train_accuracy = self._train_one_epoch(
                train_loader, criterion, optimizer, device)
            test_loss, test_accuracy = self._evaluate(test_loader, device, criterion)

            # Update the learning rate.
            scheduler.step()

            # Keep the checkpoint with the best accuracy so far.
            # NOTE(review): selection uses the test split, so the reported test
            # accuracy of the "best" model is optimistically biased.
            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
                self._save_checkpoint(best_path, optimizer, epoch, best_accuracy)

            args.logger.info(
                f"Epoch [{epoch+1}/{args.epochs}] - Train Loss: {train_loss:.4f}, "
                f"Train Accuracy: {train_accuracy:.4f}, Test Loss: {test_loss:.4f}, "
                f"Test Accuracy: {test_accuracy:.4f}")

        # Save the model from the last epoch.
        self._save_checkpoint(final_path, optimizer, epoch, best_accuracy)

    def eval_model(self, dataloader):
        """Return the classification accuracy (0..1) over every sample in ``dataloader``.

        ``dataloader`` is any iterable of ``(images, labels)`` batches.

        Raises:
            ValueError: if ``dataloader`` yields no samples.
        """
        device = self._device()
        self.to(device)
        _, accuracy = self._evaluate(dataloader, device)
        return accuracy
provider.py
包含数据读取函数GetLoader和数据可视化函数visualize_loader
# provider.py
from sklearn.preprocessing import MinMaxScaler
import torch
import torchvision
import matplotlib.pyplot as plt
def visualize_loader(loader, model=None, rows=4, cols=8):
    """Plot the first batch of ``loader`` as a ``rows`` x ``cols`` grid and save it.

    Without a model, each tile is titled with its label and the figure is
    saved as ``train.png``. With a model, each tile is titled with the model's
    prediction (red when it differs from the label) and the figure is saved
    as ``test.png``.

    Args:
        loader: iterable of ``(images, labels)`` batches; images are presumably
            (B, 1, 28, 28) tensors as produced by ``GetLoader``.
        model: optional classifier whose argmax over dim 1 is the prediction.
        rows, cols: grid shape; tiles beyond the batch size are left blank.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    batch = next(iter(loader), None)
    if batch is None:
        raise ValueError("loader yielded no batches")
    imgs = batch[0]
    labels = batch[1].cpu().numpy()

    if model is None:
        img_name = 'train.png'
        predicted = labels
    else:
        img_name = 'test.png'
        # Predict with dropout disabled and without autograd, on the model's
        # own device; restore the caller's train/eval mode afterwards.
        was_training = model.training
        model.eval()
        try:
            param = next(model.parameters(), None)
            device = param.device if param is not None else imgs.device
            with torch.no_grad():
                outputs = model(imgs.to(device))
        finally:
            model.train(was_training)
        predicted = outputs.argmax(dim=1).cpu().numpy()

    # Drop only the channel axis so a batch of one image still indexes by [i].
    pictures = imgs.squeeze(1).cpu().numpy()
    fig, axes = plt.subplots(rows, cols, figsize=(20, 10), squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= len(pictures):
            continue
        ax.imshow(pictures[i])
        ax.set_title(str(predicted[i]),
                     color='black' if predicted[i] == labels[i] else 'red')
    plt.tight_layout()
    # Save before showing: once show() returns the figure may already be
    # closed, and saving afterwards can produce a blank image.
    fig.savefig(img_name)
    plt.show()
def GetLoader(path='data', batch_size=32, loadType='train', shuffle=True, download=False):
    """Build a DataLoader over the MNIST train or test split.

    Images are converted to tensors and normalised with mean 0.1307 and
    std 0.3081. With ``batch_size=32`` the training loader yields 1875 batches
    of ``([32, 1, 28, 28], [32])``.

    Args:
        path: root directory of the MNIST data.
        batch_size: samples per batch.
        loadType: ``'train'`` selects the training split; any other value
            selects the test split.
        shuffle: reshuffle every epoch (default True, for both splits).
        download: forwarded to ``torchvision.datasets.MNIST``; when False the
            data must already be present under ``path``.

    Returns:
        torch.utils.data.DataLoader
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset = torchvision.datasets.MNIST(
        root=path, train=(loadType == 'train'), transform=transform, download=download)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
train.py
训练模型：每个epoch结束后在测试集上评估，保存准确率最高的模型 best_model.pth 和最后一个epoch的模型 final_model.pth
# train.py
import argparse
import datetime
import common_utils
import os
import network
import provider
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
def parse_config():
    """Parse the training CLI options and attach a timestamped logger.

    Every run logs to ``output/log_train_<YYYYmmdd_HHMMSS>.txt``; the ``output``
    directory is created if it does not exist yet.

    Returns:
        argparse.Namespace with ``batch_size``, ``epochs``, ``lr`` and ``logger``.
    """
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--batch_size', type=int, default=32, required=False, help='batch size for training')
    parser.add_argument('--epochs', type=int, default=7, required=False, help='number of epochs to train for')
    parser.add_argument('--lr', type=float, default=0.01, required=False, help='learning rate')

    output_dir = 'output'
    # logging.FileHandler raises FileNotFoundError if the directory is missing.
    os.makedirs(output_dir, exist_ok=True)
    log_file = os.path.join(
        output_dir, 'log_train_%s.txt' % datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))
    logger = common_utils.create_logger(log_file)
    # The logger rides along in the namespace so it appears in vars(args);
    # argparse does not apply ``type`` to a non-string default, so the default
    # is the logger object itself.
    parser.add_argument('--logger', type=type(logger), default=logger, help='logger')
    args = parser.parse_args()
    return args
def main():
    """Train an MLP, then log its accuracy on the test split."""
    args = parse_config()
    log = args.logger

    # Record every run option (including the logger itself) in the log.
    log.info('**********************Start logging**********************')
    for name, value in vars(args).items():
        log.info('{:16} {}'.format(name, value))

    log.info('**********************Start training ********************')
    model = network.MLP()
    model.train_model(args)
    log.info('**********************End training **********************')

    # Evaluate the weights left in memory after the final epoch.
    log.info('**********************Start eval ************************')
    loader = provider.GetLoader(batch_size=args.batch_size, loadType='test')
    accuracy = model.eval_model(loader)
    log.info(f'Test Accuracy: {accuracy:.4f}')
    log.info('**********************End eval **************************')
    log.info('**********************End *******************************\n')


if __name__ == '__main__':
    main()
test.py
测试模型：加载保存的模型（默认 best_model.pth），并在测试集上计算准确率
# test.py
import argparse
import datetime
import common_utils
import os
import network
import provider
import torch
import torch.nn as nn
import torch.optim as optim
def parse_config():
    """Parse the evaluation CLI options and attach a timestamped logger.

    Every run logs to ``output/log_test_<YYYYmmdd_HHMMSS>.txt``; the ``output``
    directory is created if it does not exist yet.

    Returns:
        argparse.Namespace with ``batch_size``, ``checkpoint`` and ``logger``.
    """
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--batch_size', type=int, default=32, required=False, help='batch size for evaluation')
    parser.add_argument('--checkpoint', type=str, default='best_model.pth', help='path of the checkpoint to evaluate')

    output_dir = 'output'
    # logging.FileHandler raises FileNotFoundError if the directory is missing.
    os.makedirs(output_dir, exist_ok=True)
    log_file = os.path.join(
        output_dir, 'log_test_%s.txt' % datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))
    logger = common_utils.create_logger(log_file)
    # The logger rides along in the namespace so it appears in vars(args).
    parser.add_argument('--logger', type=type(logger), default=logger, help='logger')
    args = parser.parse_args()
    return args
def main():
    """Entry point: parse options, log them, then evaluate the checkpoint."""
    args = parse_config()
    log = args.logger
    log.info('**********************Start logging**********************')
    for name, value in vars(args).items():
        log.info('{:16} {}'.format(name, value))
    log.info('**********************Start testing **********************')
    test(args)
    log.info('**********************End testing ************************\n\n')
def test(args):
    """Load ``args.checkpoint`` into a fresh MLP and log its test-split accuracy."""
    # map_location='cpu' lets a checkpoint written on a GPU machine load on a
    # CPU-only one; eval_model moves the model to the available device itself.
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    model = network.MLP()
    model.load_state_dict(checkpoint['model_state_dict'])
    args.logger.info(model)
    test_loader = provider.GetLoader(batch_size=args.batch_size, loadType='test')
    test_accuracy = model.eval_model(test_loader)
    args.logger.info(f'Test Accuracy: {test_accuracy:.4f}')


if __name__ == '__main__':
    main()
visual.py
可视化代码：分别展示一个训练批次（标题为真实标签）和一个测试批次（标题为模型预测，预测错误时显示为红色）
# visual.py
import provider
import network
import torch
# Show one training batch titled with its ground-truth labels.
train_loader = provider.GetLoader(loadType='train')
provider.visualize_loader(train_loader)

# Show one test batch titled with the best checkpoint's predictions.
test_loader = provider.GetLoader(loadType='test')
# map_location='cpu': the checkpoint may come from a GPU run, while the model
# built below lives on the CPU.
checkpoint = torch.load('best_model.pth', map_location='cpu')
model = network.MLP()
model.load_state_dict(checkpoint['model_state_dict'])
# A freshly built module is in training mode; disable dropout so the
# displayed predictions are deterministic.
model.eval()
provider.visualize_loader(test_loader, model)
实验
训练结果
[2023-07-22 10:45:31,237 train.py 30 INFO] **********************Start logging**********************
[2023-07-22 10:45:31,237 train.py 32 INFO] batch_size 32
[2023-07-22 10:45:31,237 train.py 32 INFO] epochs 7
[2023-07-22 10:45:31,237 train.py 32 INFO] lr 0.01
[2023-07-22 10:45:31,237 train.py 32 INFO] logger <Logger common_utils (INFO)>
[2023-07-22 10:45:31,237 train.py 34 INFO] **********************Start training ********************
[2023-07-22 10:45:46,963 network.py 106 INFO] Epoch [1/7] - Train Loss: 0.5768, Train Accuracy: 0.8446, Test Accuracy: 0.9037
[2023-07-22 10:45:59,299 network.py 106 INFO] Epoch [2/7] - Train Loss: 0.5059, Train Accuracy: 0.8759, Test Accuracy: 0.9299
[2023-07-22 10:46:11,687 network.py 106 INFO] Epoch [3/7] - Train Loss: 0.4536, Train Accuracy: 0.8884, Test Accuracy: 0.9198
[2023-07-22 10:46:24,010 network.py 106 INFO] Epoch [4/7] - Train Loss: 0.3161, Train Accuracy: 0.9196, Test Accuracy: 0.9502
[2023-07-22 10:46:36,307 network.py 106 INFO] Epoch [5/7] - Train Loss: 0.2497, Train Accuracy: 0.9350, Test Accuracy: 0.9528
[2023-07-22 10:46:48,712 network.py 106 INFO] Epoch [6/7] - Train Loss: 0.2280, Train Accuracy: 0.9395, Test Accuracy: 0.9549
[2023-07-22 10:47:01,138 network.py 106 INFO] Epoch [7/7] - Train Loss: 0.2078, Train Accuracy: 0.9443, Test Accuracy: 0.9573
[2023-07-22 10:47:01,155 train.py 37 INFO] **********************End training **********************
[2023-07-22 10:47:01,155 train.py 40 INFO] **********************Start eval ************************
[2023-07-22 10:47:02,492 train.py 43 INFO] Test Accuracy: 0.9573
[2023-07-22 10:47:02,493 train.py 44 INFO] **********************End eval **************************
[2023-07-22 10:47:02,493 train.py 45 INFO] **********************End *******************************
测试结果
[2023-07-22 10:50:46,173 test.py 24 INFO] **********************Start logging**********************
[2023-07-22 10:50:46,173 test.py 26 INFO] batch_size 32
[2023-07-22 10:50:46,173 test.py 26 INFO] checkpoint best_model.pth
[2023-07-22 10:50:46,173 test.py 26 INFO] logger <Logger common_utils (INFO)>
[2023-07-22 10:50:46,173 test.py 27 INFO] **********************Start testing **********************
[2023-07-22 10:50:49,084 test.py 36 INFO] MLP(
(flatten): Flatten(start_dim=1, end_dim=-1)
(fc1): Linear(in_features=784, out_features=300, bias=True)
(fc2): Linear(in_features=300, out_features=100, bias=True)
(fc3): Linear(in_features=100, out_features=10, bias=True)
(relu): ReLU()
(softmax): LogSoftmax(dim=1)
(dropout): Dropout(p=0.2, inplace=False)
)
[2023-07-22 10:50:50,970 test.py 39 INFO] Test Accuracy: 0.9573
[2023-07-22 10:50:50,970 test.py 29 INFO] **********************End testing ************************
可视化
测试结果