前言:
这里结合一个ResNet-18 网络,讲解一下自己定义一个深度学习网络的完整流程。
经过20轮训练,模型在测试集上的精度为85%。
一 残差块定义
针对图像处理,残差块有两种结构,下面的代码实现的是左边的结构。
# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 12:00:57 2023
@author: chengxf2
"""
import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
    """Basic two-convolution residual block.

    Main path: conv3x3(stride=step) -> BN -> ReLU -> conv3x3 -> BN -> ReLU.
    Shortcut:  identity, or conv1x1(stride=step) -> BN whenever the main
    path changes the channel count or the spatial resolution.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        step: stride of the first convolution (and of the shortcut projection).
    """

    def __init__(self, in_ch, out_ch, step):
        super(ResBlk, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_ch,
                               out_channels=out_ch,
                               kernel_size=3,
                               stride=step,
                               padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(in_channels=out_ch,
                               out_channels=out_ch,
                               kernel_size=3,
                               stride=1,
                               padding=1)
        self.bn2 = nn.BatchNorm2d(out_ch)

        # Identity shortcut by default (an empty Sequential returns its input).
        self.extra = nn.Sequential()
        # A projection is required when the main path changes the channel
        # count OR the spatial size; checking only the channels would make the
        # residual addition fail for in_ch == out_ch with step != 1.
        if step != 1 or in_ch != out_ch:
            self.extra = nn.Sequential(
                # [b, in_ch, h, w] => [b, out_ch, h', w']
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=step),
                nn.BatchNorm2d(out_ch),
            )

    def forward(self, x):
        """Apply the residual block.

        Args:
            x: input tensor of shape [b, in_ch, h, w].

        Returns:
            Tensor of shape [b, out_ch, h', w'], where h' and w' are the
            spatial sizes produced by the stride-``step`` convolution.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        # NOTE(review): the reference ResNet block applies ReLU only after the
        # addition; the extra ReLU here is kept so that existing checkpoints
        # keep producing the same outputs.
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.extra(x) + out
        return F.relu(out)
二 定义网络
# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 14:22:34 2023
@author: chengxf2
"""
import torch
from torch import nn
from torch.nn import functional as F
from ResBlock import ResBlk
class ResNet18(nn.Module):
    """Small ResNet-style classifier: a stem convolution, four ResBlk stages
    and a linear head.

    Despite its name the network only stacks four residual blocks, each of
    which downsamples with stride 3. For a 224x224 input the spatial size
    goes 224 -> 111 (stem) -> 37 -> 13 -> 5 -> 2.

    Args:
        num_class: number of output classes (size of the logits vector).
    """

    def __init__(self, num_class):
        super(ResNet18, self).__init__()
        conv = nn.Conv2d(in_channels=3,
                         out_channels=16,
                         kernel_size=3,
                         stride=2,
                         padding=0)
        bn = nn.BatchNorm2d(16)
        self.conv1 = nn.Sequential(conv, bn)

        # Four residual stages; each changes the channel count and
        # downsamples by its stride of 3.
        # [b, 16, h, w] => [b, 32, h', w']
        self.blk1 = ResBlk(16, 32, 3)
        # [b, 32, h, w] => [b, 64, h', w']
        self.blk2 = ResBlk(32, 64, 3)
        # [b, 64, h, w] => [b, 128, h', w']
        self.blk3 = ResBlk(64, 128, 3)
        # [b, 128, h, w] => [b, 256, h', w']
        self.blk4 = ResBlk(128, 256, 3)

        # The linear head expects a 2x2 feature map. Adaptive pooling pins
        # that size for any input resolution; for 224x224 inputs the map is
        # already 2x2, so the pooling is an identity there. The layer has no
        # parameters, so state_dict keys are unchanged.
        self.pool = nn.AdaptiveAvgPool2d((2, 2))
        self.fc = nn.Linear(256 * 2 * 2, num_class)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: image batch of shape [b, 3, h, w].

        Returns:
            Logits tensor of shape [b, num_class].
        """
        a = F.relu(self.conv1(x))
        a = self.blk1(a)
        a = self.blk2(a)
        a = self.blk3(a)
        a = self.blk4(a)
        a = self.pool(a)
        a = a.view(a.size(0), -1)  # flatten to [b, 256*2*2]
        return self.fc(a)
def main():
    """Smoke-test ResBlk and ResNet18 on random inputs and print the output
    shapes together with the model's total parameter count."""
    # Single residual block; input layout is [batch, channel, height, width].
    block = ResBlk(64, 128, 2)
    block_in = torch.randn(2, 64, 224, 224)
    block_out = block(block_in)
    print("\n resBlock: ", block_out.shape)

    # Full network on an RGB batch.
    model = ResNet18(5)
    images = torch.randn(2, 3, 224, 224)
    logits = model(images)
    print("resnet-18 ", logits.shape)

    # numel() is the number of elements in a tensor, so the sum over all
    # parameters is the total count of scalar weights.
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()
    print("\n parameters size ", total_params)


if __name__ == "__main__":
    main()
三 Train& Test
逻辑如下:
先使用训练集数据训练
使用验证集数据进行过拟合检查,并保存验证精度最高的模型参数
加载模型参数,进行测试
# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 15:28:13 2023
@author: chengxf2
"""
for epoch in range(epochs):
train(train_db)
if epoch %10 ==0:
val_acc = evaluate(val_db)
if val_acc is the best:
#保存模型参数,防止过拟合
save_ckpt()
if out_of_patience():
break
#加载模型参数
load_ckpt()
test_acc = evaluate(test_db)
四 训练,验证,测试部分完整代码
# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 15:38:18 2023
@author: chengxf2
"""
import torch
from torch import optim,nn
import visdom
from torch.utils.data import DataLoader
from ResNet_18 import ResNet18
from PokeDataset import Pokemon
# ---- Hyper-parameters (read as globals by main() below) ----
# Mini-batch size passed to every DataLoader.
batchNum = 32
# Learning rate handed to the Adam optimizer in main().
lr = 1e-3
# Number of passes over the training loader.
epochs = 20
# Use the first CUDA device when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Seed torch's RNG before the datasets and the model are created.
torch.manual_seed(1234)
# ---- Dataset configuration ----
# NOTE(review): Pokemon is a project-local dataset class; presumably `root` is
# the image folder, `resize` the target image side length and `csvfile` a
# cached file list -- confirm against PokeDataset.py.
root ='pokemon'
resize =224
csvfile ='data.csv'
# One dataset object per split; the split is selected by the third argument.
train_db = Pokemon(root, resize, 'train',csvfile)
val_db = Pokemon(root, resize, 'val',csvfile)
test_db = Pokemon(root, resize, 'test',csvfile)
# All three loaders shuffle their samples, including validation and test.
train_loader = DataLoader(train_db, batch_size =batchNum,shuffle= True,num_workers=4)
val_loader = DataLoader(val_db, batch_size =batchNum,shuffle= True,num_workers=2)
test_loader = DataLoader(test_db, batch_size =batchNum,shuffle= True,num_workers=2)
# Visdom client used by main() for the loss / accuracy plots.
# NOTE(review): presumably needs a running visdom server -- confirm.
viz = visdom.Visdom()
def evalute(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is switched to eval mode for the duration of the call, so
    BatchNorm layers use their running statistics and are not updated by
    validation/test data. The previous train/eval mode is restored
    afterwards, even if an exception is raised.

    Args:
        model: network mapping an input batch to logits of shape
            [b, num_class].
        loader: iterable of ``(x, y)`` batches that exposes ``.dataset``;
            ``y`` holds integer class indices of shape [b].

    Returns:
        Fraction of correctly classified samples as a float in [0, 1];
        ``0.0`` when the dataset is empty.
    """
    total = len(loader.dataset)
    if total == 0:
        # Nothing to evaluate; avoid dividing by zero.
        return 0.0

    # Run on whatever device the model lives on; fall back to the
    # module-level ``device`` for models without parameters.
    try:
        target = next(model.parameters()).device
    except StopIteration:
        target = device

    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(target)
                y = y.to(target)
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)
    return correct / total
def main():
    """Train ResNet18, keep the checkpoint with the best validation accuracy
    and report the test accuracy of that checkpoint.

    Reads the module-level configuration (``device``, ``lr``, ``epochs``, the
    three data loaders and the ``viz`` client). Writes the best weights to
    ``best.mdl`` in the current working directory.
    """
    model = ResNet18(5).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_epoch = 0
    # Start below any reachable accuracy so the first validation pass always
    # writes a checkpoint; otherwise a run that never beats 0 would later
    # load a missing or stale 'best.mdl'.
    best_acc = float("-inf")

    # Seed both plots with a dummy point so later calls can append to them.
    viz.line([0], [-1], win='train_loss', opts=dict(title='train loss'))
    viz.line([0], [-1], win='val_loss', opts=dict(title='val_acc'))

    global_step = 0
    for epoch in range(epochs):
        print("\n --main---: ", epoch)
        # Make sure BatchNorm is in training mode for the optimisation steps.
        model.train()
        for x, y in train_loader:
            # x: [b, 3, 224, 224]  y: [b]
            x = x.to(device)
            y = y.to(device)

            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line(Y=[loss.item()], X=[global_step],
                     win='train_loss', update='append')
            global_step += 1

        # Validate every second epoch and keep the best-performing weights.
        if epoch % 2 == 0:
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch
                torch.save(model.state_dict(), 'best.mdl')
            print("\n val_acc ", val_acc)
            viz.line([val_acc], [global_step], win='val_loss', update='append')

    print('\n best acc', best_acc, "best_epoch: ", best_epoch)

    # Evaluate the best checkpoint, not the weights from the last epoch.
    model.load_state_dict(torch.load('best.mdl', map_location=device))
    print('loaded from ckpt')
    test_acc = evalute(model, test_loader)
    print('\n test acc', test_acc)


if __name__ == "__main__":
    main()