目录
- 图片分类任务方法概述
- 卷积神经网络(CNN)
- 视觉Transformer(ViT)
- 视觉图神经网络(ViG)
- ViG模型
- 模型架构
- 图像输入
- 图结构生成
- 网络模块
- 图处理
- 特征变换
- 多尺度处理
- 输出头
- ViG代码
- 模型主体架构设计
- 核心代码
- 演示效果
- 附件使用
- 参考文献:需要本文的详细复现过程的项目源码、数据和预训练好的模型可从该地址处获取完整版:地址
图片分类任务方法概述
卷积神经网络(CNN)
- 发展背景: CNN的出现标志着深度学习在图像识别领域的重大突破。最早的CNN模型可以追溯到1998年的LeNet,而2012年的AlexNet模型在ImageNet竞赛中取得优异成绩,使得CNN成为图像分类任务的主流方法。
分类方法优点:
- 局部感知野: 通过卷积操作,CNN能够捕捉图像的局部特征,减少参数数量。
- 参数共享: 卷积核在整张图像上共享,提高了模型的泛化能力。
- 平移不变性: CNN具有平移不变性,能够识别图像中的物体,即使它们的位置发生变化。
视觉Transformer(ViT)
- 发展背景: ViT于2020年被提出,借鉴了自然语言处理领域的Transformer架构,将自注意力机制应用于图像分类任务。
分类方法优点:
- 自注意力机制: 能够捕捉图像中的长距离依赖关系,提高分类准确性。
- 可扩展性: Transformer结构易于扩展,适用于大规模数据集。
- 并行计算: 自注意力机制使得ViT能够更好地利用并行计算资源。
视觉图神经网络(ViG)
- 发展背景: ViG的提出是为了解决CNN和ViT在处理不规则和复杂目标时的局限性。ViG将图像视为图结构,通过图卷积操作进行特征提取和分类。
分类方法优点:
- 灵活的图结构: ViG采用图结构表示图像,能够更好地处理不规则形状的物体,提高对复杂场景的识别能力。
- 图卷积操作: 通过图卷积,ViG能够有效地聚合和更新节点信息,捕捉局部和全局特征。
- 节点特征变换: FFN模块(多层感知器)用于节点特征变换,增强了模型的表达能力
ViG模型
图片切成patch
(a) Grid Structure
-
作用:
像素级信息捕获:通过将图像切分成均匀分布的小块(Patch),每个Patch代表图像的一个局部区域。
空间关系保持:保留了图像的空间布局信息,使得模型能够理解对象的位置和相对位置。 -
重要性:
经典方法的基础:这是许多传统计算机视觉算法的基本假设,包括早期的人工设计特征提取方法和现代的深度学习模型(如卷积神经网络CNN)。
简单直观:易于理解和实施,是初学者入门的好选择。
(b) Sequence Structure -
作用:
序列化处理:将图像的Patch按某种顺序排列,形成一维序列。
时间维度模拟:虽然实际处理的是静态图像,但通过序列化的方式,可以引入类似于自然语言处理(NLP)领域的时间维度概念。 -
重要性:
Transformer的应用:这种结构特别适合于基于Transformer架构的方法,如Vision Transformer(ViT)。ViT等模型通过自注意力机制对序列化的Patch进行处理,从而有效地捕捉全局上下文信息。
灵活性提升:相比固定大小的卷积核,序列化处理允许模型关注任意距离的Patch之间的关系,提高了模型的灵活性和泛化能力。
© Graph Structure -
作用:
非结构化数据建模:将图像中的Patch视为图中的节点,允许模型处理更加复杂和灵活的数据结构。
适应性强:能够更好地适应各种形状和尺寸的对象,尤其是对于那些不能很好地用网格或序列描述的情况。 -
重要性:
图神经网络优势:结合图神经网络(GNN)的优点,能够有效处理具有复杂拓扑结构的数据,如社交网络、分子结构等。
创新性突破:在视觉任务中引入图结构是一种创新尝试,有望带来新的突破和进展,特别是在需要精细分析和理解场景的情况下。
模型架构
图像输入
首先,从一张原始图像开始。在这个例子中,图像展示了一条鱼和一个人的部分身体。
图结构生成
接下来,将图像划分为若干个Patch,并将这些Patch作为图中的节点。每个节点代表图像的一部分,而边则表示这些部分之间的关联。红色圆圈内的节点可能表示图像的关键部分,比如鱼的身体或者人的衣服图案。
网络模块
然后,进入网络模块,该模块由两部分组成:图处理和特征变换。
图处理
在这一步骤中,模型会对图结构进行处理,以提取出各个Patch之间的关系和相互影响。这可以通过图卷积操作或其他类型的图神经网络技术完成。
特征变换
经过图处理之后,得到的特征会被送入特征变换模块。这里可能会涉及到一些标准的神经网络组件,如全连接层、激活函数等,目的是进一步提炼和转化所获得的信息。
多尺度处理
整个过程会重复多次(L次),每次都会产生一个新的特征图。这样做的好处是可以从不同的层次和角度来观察和理解图像内容,增强模型的表现力。
输出头
最后,所有经过多轮处理后的特征被整合起来,传递给输出头(Head for recognition)。这个输出头负责最终的识别任务,可能是分类、回归或者其他类型的问题。
ViG代码
PatchEmbedding
class Stem(nn.Module):
def __init__(self, img_size=224, in_dim=3, out_dim=768, act='relu'):
super().__init__()
self.convs = nn.Sequential(
nn.Conv2d(in_dim, out_dim//8, 3, stride=2, padding=1),
nn.BatchNorm2d(out_dim//8),
act_layer(act),
nn.Conv2d(out_dim//8, out_dim//4, 3, stride=2, padding=1),
nn.BatchNorm2d(out_dim//4),
act_layer(act),
nn.Conv2d(out_dim//4, out_dim//2, 3, stride=2, padding=1),
nn.BatchNorm2d(out_dim//2),
act_layer(act),
nn.Conv2d(out_dim//2, out_dim, 3, stride=2, padding=1),
nn.BatchNorm2d(out_dim),
act_layer(act),
nn.Conv2d(out_dim, out_dim, 3, stride=1, padding=1),
nn.BatchNorm2d(out_dim),
)
def forward(self, x):
x = self.convs(x)
return x
模型主体架构设计
self.backbone = Seq(*[Seq(Grapher(channels, num_knn[i], 1, conv, act, norm,
bias, stochastic, epsilon, 1, drop_path=dpr[i]),
FFN(channels, channels * 4, act=act, drop_path=dpr[i])
) for i in range(self.n_blocks)])
核心代码
聚合特征
class MRConv2d(nn.Module):
def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
super(MRConv2d, self).__init__()
self.nn = BasicConv([in_channels*2, out_channels], act, norm, bias)
def forward(self, x, edge_index, y=None):
print(x.shape, edge_index.shape)
x_i = batched_index_select(x, edge_index[1])
print(x_i.shape)
if y is not None:
x_j = batched_index_select(y, edge_index[0])
else:
x_j = batched_index_select(x, edge_index[0])
print(x_j.shape)
x_j, _ = torch.max(x_j - x_i, -1, keepdim=True)
b, c, n, _ = x.shape
x = torch.cat([x.unsqueeze(2), x_j.unsqueeze(2)], dim=2).reshape(b, 2 * c, n, _)
print(x.shape)
return self.nn(x)
演示效果
附件使用
- 安装相应依赖包
pip install -r requirements.txt
- 获取cifa10数据集
import torchvision
import torchvision.transforms as transforms
- transforms用于数据预处理
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
- 下载并加载训练数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
# 下载并加载测试数据集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
- CIFAR-10数据集中的类别
import torchvision
import torchvision.transforms as transforms
# transforms用于数据预处理
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 下载并加载训练数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
# 下载并加载测试数据集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
# CIFAR-10数据集中的类别
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
- 运行代码
python train.py