随着深度学习的不断发展,神经网络的层数越来越深。然而,网络层数的增加并不总是带来性能的提升,反而可能导致梯度消失或梯度爆炸等问题。为了解决这些问题,何恺明等人在2015年提出了残差网络(ResNet),这在深度学习领域引起了革命性的突破。
什么是ResNet?
ResNet,全称为Residual Network,即残差网络。它的核心思想是通过引入残差块(Residual Block),使得信息可以在网络中跨层传播,从而减轻深度网络训练中的退化问题。
传统的深度神经网络试图直接学习输入到输出的映射,而ResNet引入了恒等映射(Identity Mapping),允许网络只需学习输入与输出之间的残差(Residual)。这种方式大大降低了训练难度,使得网络可以堆叠数百甚至上千层。
ResNet的网络架构
ResNet的关键组件是残差块,其主要特征是引入了快捷链接(Shortcut Connection),也称为跳跃连接(Skip Connection)。这种连接方式允许输入信息直接传递到后面的层,形成一个恒等映射。
残差块结构
下面是ResNet的基本残差块结构:
在这个结构中,输入 x x x经过一系列卷积层得到输出 F ( x ) F(x) F(x),然后直接将输入 x x x加到 F ( x ) F(x) F(x)上,得到最终的输出 y y y: y = F ( x ) + x y=F(x)+x y=F(x)+x
这种结构的优势在于,如果 F ( x ) F(x) F(x)学习到的参数为零,那么输出 y y y就等于输入 x x x,网络就退化为恒等映射。这使得深层网络的训练变得更加稳定。
ResNet的使用
在深度学习领域,我们主要使用Python语言配合深度学习框架(如PyTorch或TensorFlow)来构建和使用ResNet。下面我们以PyTorch为例,展示如何使用预训练的ResNet模型。
安装PyTorch
首先,要确保已安装PyTorch:
# bash
pip install torch torchvision
导入必要的库
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
加载预训练的ResNet模型
# 加载ResNet-18模型
resnet18 = models.resnet18(pretrained=True)
图像预处理
# 定义图像预处理步骤
preprocess = transforms.Compose([
transforms.Resize(256), # 调整图像尺寸
transforms.CenterCrop(224), # 中心裁剪
transforms.ToTensor(), # 转换为Tensor
transforms.Normalize( # 标准化
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
加载和预处理图像
# 加载图像
img = Image.open("path_to_image.jpg") # 替换为实际的图像路径
img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, 0) # 增加批次维度
进行推理
# 设置模型为评估模式
resnet18.eval()
# 前向传播
out = resnet18(batch_t)
# 获取预测结果
_, index = torch.max(out, 1)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
打印预测结果
# 读取ImageNet的类别标签
with open("imagenet_classes.txt") as f:
labels = [line.strip() for line in f.readlines()]
print(f"预测类别:{labels[index[0]]}, 概率:{percentage[index[0]].item():.2f}%")
Demo:使用ResNet进行图像分类
下面,我们将完整地运行一个小Demo,使用ResNet-18对一张图像进行分类。
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# 加载预训练模型
resnet18 = models.resnet18(pretrained=True)
resnet18.eval()
# 图像预处理
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
# 加载图像
img = Image.open("path_to_image.jpg")
img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, 0)
# 前向传播
out = resnet18(batch_t)
# 获取预测结果
_, index = torch.max(out, 1)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
# 读取ImageNet的类别标签
with open("imagenet_classes.txt") as f:
labels = [line.strip() for line in f.readlines()]
print(f"预测类别:{labels[index[0]]}, 概率:{percentage[index[0]].item():.2f}%")
注意:请确保在同一目录下有imagenet_classes.txt
文件,其中包含ImageNet的1000个类别(可以关注公众号非鱼AI视界
回复imagenet_classes
领取)。
总结
ResNet的引入极大地推动了深度学习的发展,使得训练超深层的神经网络成为可能,在实际应用中,ResNet已经成为计算机视觉领域的景点模型,被广泛应用图像分类、目标检测等任务。
参考文献
- K. He, X. Zhang, S. Ren and J. Sun, “Deep Residual Learning for Image Recognition,” 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Las Vegas, NV, USA, 2016, pp. 770-778, doi: 10.1109/CVPR.2016.90.
- PyTorch官方文档
- Torchvision