文章目录
- 训练LeNet模型
- 下载FashionMNIST数据
- 训练
- 保存模型
- 卷积神经网络可视化
- 加载模型
- 一个测试图像
- 不同层对图像处理的可视化
- 第一个卷积层的处理
- 第二个卷积层的处理
卷积神经网络是利用图像空间结构的一种深度学习网络架构,图像在经过卷积层、激活层、池化层、全连接层等处理后得到输出。
本次想探索一下图像经过每一层都发生了什么变化,比如不同的卷积核(滤波器)都提取了图像的什么特征?越深层是否会对图像更抽象化?
带着这些问题,本文将使用FashionMNIST数据、简单的LeNet模型来探索CNN是如何处理图像的。
训练LeNet模型
首先来训练一个LeNet模型(换成其他卷积神经网络也可以),目的是为了利用训练好的模型参数获得输入图像的各层输出,以供可视化之用。
下载FashionMNIST数据
root:设置下载路径;
train:为True表示下载训练集,反之为测试集;
download:首次下载设为True,下载好后可以改为False。
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
from matplotlib import pyplot as plt
%matplotlib inline
from PIL import Image
import torch
from torch import nn
import torchvision
from torch.utils import data
from torchvision import transforms
#下载数据
def load_fashion_mnist(batch_size):
trans = transforms.Compose([transforms.ToTensor()])
train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=False)
test = torchvision.datasets.FashionMNIST(root="../data", train=False,transform=trans,download=False)
return (data.DataLoader(train, batch_size, shuffle=True), data.DataLoader(test, batch_size, shuffle=False))
训练
简单地训练网络
#批量大小
batch_size = 512
train_iter, test_iter = load_fashion_mnist(batch_size=batch_size)
#LeNet网络
net = nn.Sequential(
nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Flatten(),
nn.Linear(16 * 5 * 5, 120), nn.ReLU(),
nn.Linear(120, 84), nn.ReLU(),
nn.Linear(84, 10))
#参数初始化
def init_network(model, method='xavier'):
for name, w in model.named_parameters():
if 'weight' in name:
if method == 'xavier':
nn.init.xavier_normal_(w)
elif method == 'kaiming':
nn.init.kaiming_normal_(w)
else:
nn.init.normal_(w)
elif 'bias' in name:
nn.init.constant_(w, 0)
else:
pass
init_network(net)
#损失函数
loss=nn.CrossEntropyLoss()
#优化算法
lr=0.05
updater=torch.optim.SGD(net.parameters(),lr=lr)
#训练
def train(net, train_iter, test_iter, loss, num_epochs, updater,device):
net.to(device)
for epoch in range(num_epochs):
if isinstance(net, torch.nn.Module):
net.train()
for X, y in train_iter:
X,y=X.to(device),y.to(device)
y_hat = net(X)
l = loss(y_hat, y)
updater.zero_grad()
l.backward()
updater.step()
device = torch.device("mps" if torch.backends.mps.is_available else "cpu") #Mac使用mps
num_epochs = 20
train(net, train_iter, test_iter, loss, num_epochs, updater,device)
保存模型
#保存模型参数
torch.save(net.state_dict(),'LeNet.params')
卷积神经网络可视化
本节将使用上文训练好的模型来可视化卷积神经网络不同层对图像的处理过程。
加载模型
#LeNet网络结构
net = nn.Sequential(
nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Flatten(),
nn.Linear(16 * 5 * 5, 120), nn.ReLU(),
nn.Linear(120, 84), nn.ReLU(),
nn.Linear(84, 10))
#加载模型参数
net.load_state_dict(torch.load('LeNet.params'))
net.eval()
一个测试图像
#批量大小
batch_size = 1
train_iter, test_iter = load_fashion_mnist(batch_size=batch_size)
#随机选择一个图像
for x,y in train_iter:
_=plt.imshow(x.squeeze(0).permute(1,2,0).numpy())
break
看上去我们抽到了一件T恤。
不同层对图像处理的可视化
#计算到给定层的输出
def cnn_net(X,net,l=1):
for i,layer in enumerate(net[0:l]):
X=layer(X)
if i==l-1:
print('第%s层:%-10s 输出形状:%s'%(i+1, layer.__class__.__name__, X.shape))
return X
#可视化
def cnn_visual(imgs,nrows,ncols,scale):
figsize = (ncols * scale, nrows * scale)
fig,axes = plt.subplots(nrows=nrows, ncols=ncols,figsize=figsize)
axes = axes.flatten()
for i, (ax, img) in enumerate(zip(axes, imgs.squeeze(0))):
_ = ax.imshow(img.detach().numpy())
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
return axes
for l in [1,2,3,4,5,6]:
imgs=cnn_net(x.clone(),net,l=l)
nrows=2
ncols=int(imgs.shape[1]/nrows)
axes=cnn_visual(imgs,nrows,ncols,2)
第一个卷积层的处理
我们先来看看第一个卷积层中不同卷积核分别从图像中提取了什么信息,第一个卷积层有6个输出通道,因此查看每个通道输出的图像。
从下图可以看出,第一个卷积层提取到了不同轮廓层次信息。
经过ReLU处理后:
再经平均池化处理后,变化不大:
第二个卷积层的处理
第二个卷积层有16个输出通道,随着层次加深,感受野扩大,通道的融合后,从下图看已经比较抽象了,但隐隐约约还能看出点端倪:
再经ReLU和池化处理后,基本上已经面目全非:
以上就是对卷积神经网络可视化的初步探索,感兴趣的读者可以在不同卷积神经网络和图像上多做尝试。
另附一个卷积神经网络可视化网站