This section uses the Vision Transformer (ViT) for image classification.
Background
Vision Transformer
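ViT brings together computer vision and natural language processing: it applies the Transformer architecture, originally designed for natural language, to image data.
The main idea of ViT is to split an image into fixed-size blocks (patches) and arrange them as a sequence, the counterpart of a sequence of words in NLP. ViT then processes this sequence with a Transformer encoder.
Each embedded patch acts as a visual token. A sequence by itself says nothing about where each patch came from, so ViT adds a position embedding to every token. This lets the model use both the content and the location of each patch.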
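Key characteristics of the ViT model:
- The image is split into patches, and each 2-D patch is flattened into a 1-D vector. A class token and position embeddings are added to form the model input.
- The main building block follows the Transformer encoder, except that normalization is applied before each sublayer rather than after it (pre-norm). The core component is still multi-head attention.
- The Transformer encoder part is called the backbone, and the final fully connected layer is called the head.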
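The sketch below makes the patch arithmetic concrete. It computes the token layout for the defaults used by the ViT class later in this section (image_size=224, patch_size=16, input_channels=3). The function name `vit_token_layout` is hypothetical and exists only for illustration.

```python
def vit_token_layout(image_size: int = 224, patch_size: int = 16, channels: int = 3):
    """Return (patches per side, number of patches, sequence length incl. class token, raw values per patch)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size                 # 224 // 16 = 14
    num_patches = per_side ** 2                         # 14 * 14 = 196
    seq_len = num_patches + 1                           # one extra class token
    patch_values = patch_size * patch_size * channels   # 16 * 16 * 3 = 768 raw values
    return per_side, num_patches, seq_len, patch_values


if __name__ == "__main__":
    print(vit_token_layout())  # (14, 196, 197, 768)
```

The 768 raw values per patch happen to equal the ViT-B/16 embedding width, but the two need not match. The patch embedding is a learned projection to embed_dim either way.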
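Transformer fundamentals
This part focuses on the encoder/decoder structure; ViT itself uses only the encoder. Both the encoder and the decoder are built from multi-head attention layers, feed-forward layers, normalization layers, and residual connections. The most important of these is multi-head attention. It is based on the self-attention mechanism and consists of several self-attention operations running in parallel.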
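The sketch below shows what "normalization before each sublayer" means in ViT: one pre-norm encoder layer, in pure Python. Attention and feed-forward are passed in as plain functions. All helper names are hypothetical, and LayerNorm's learnable scale and shift are omitted for brevity.

```python
import math
from typing import Callable, List

Sequence = List[List[float]]  # one feature list per token


def layer_norm(seq: Sequence, eps: float = 1e-6) -> Sequence:
    """Normalize each token to zero mean / unit variance (no learnable gain or bias here)."""
    out = []
    for token in seq:
        mean = sum(token) / len(token)
        var = sum((v - mean) ** 2 for v in token) / len(token)
        out.append([(v - mean) / math.sqrt(var + eps) for v in token])
    return out


def pre_norm_residual(seq: Sequence, sublayer: Callable[[Sequence], Sequence]) -> Sequence:
    """x + sublayer(LayerNorm(x)): normalize first, then add the untouched input back."""
    y = sublayer(layer_norm(seq))
    return [[a + b for a, b in zip(x_tok, y_tok)] for x_tok, y_tok in zip(seq, y)]


def encoder_layer(seq: Sequence,
                  attention: Callable[[Sequence], Sequence],
                  feed_forward: Callable[[Sequence], Sequence]) -> Sequence:
    """One encoder layer: pre-norm attention block, then pre-norm feed-forward block."""
    seq = pre_norm_residual(seq, attention)
    return pre_norm_residual(seq, feed_forward)


def zero_sublayer(seq: Sequence) -> Sequence:
    """A stand-in sublayer that outputs zeros, leaving only the residual path."""
    return [[0.0] * len(tok) for tok in seq]


if __name__ == "__main__":
    x = [[1.0, 2.0, 3.0], [4.0, 4.0, 4.0]]
    print(encoder_layer(x, zero_sublayer, zero_sublayer))  # residual path only: x unchanged
```

The residual path carries the un-normalized input straight through. This is why a sublayer that outputs zeros leaves the sequence unchanged.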
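Attention
Self-attention learns, for every token in the input, a set of weights over all the tokens. Given a task-related Query vector, we compute the similarity (relevance) between the query and every Key to obtain an attention distribution. This distribution gives a weight for the Value associated with each key. The weighted sum of the values is the attention output.
Concretely, linear projections map the input vectors into three parallel sets of vectors: Q, K, and V. The dot products of Q and K, scaled and normalized with softmax, give the attention weights. These weights are then used to take a weighted sum of V, producing the output vectors:
- Step 1: take the dot product of each q with every k, divide by √d_k (d_k is the key dimension), and apply softmax to get the weights.
- Step 2: use these weights to form a weighted sum of the V vectors, which gives the final result.
In compact form: Attention(Q, K, V) = softmax(QKᵀ / √d_k) V.
It is called self-attention because Q, K, and V are all derived from the same input sequence. They capture the relationships between tokens at different positions in the input. How strongly two positions are related shows up as the size of the corresponding attention weight.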
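The two steps above translate almost line for line into code. The pure-Python sketch below implements scaled dot-product attention, then uses it as self-attention by taking Q, K, and V from the same input. It assumes identity projections; a real layer uses learned linear projections instead. The names are illustrative.

```python
import math
from typing import List

Matrix = List[List[float]]  # rows are tokens


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """softmax(Q K^T / sqrt(d_k)) V, computed one query row at a time."""
    scale = 1.0 / math.sqrt(len(k[0]))
    out = []
    for q_row in q:
        # Step 1: scaled dot product of this query with every key, then softmax
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        weights = softmax(scores)
        # Step 2: weighted sum of the value rows
        out.append([sum(w * v_row[j] for w, v_row in zip(weights, v))
                    for j in range(len(v[0]))])
    return out


def self_attention(x: Matrix) -> Matrix:
    """Simplest self-attention: Q, K and V are all x itself (identity projections)."""
    return scaled_dot_product_attention(x, x, x)


if __name__ == "__main__":
    x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    for row in self_attention(x):
        print([round(val, 3) for val in row])
```

A useful sanity check is the case where all keys are identical. The weights are then uniform, so each output row is simply the mean of the value rows.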
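Multi-head attention
Multi-head attention splits the representation that a single self-attention would process into several heads. Each head has its own independent linear projections and processes the input sequence in parallel. The computation runs in four steps:
1. The input sequence passes through h independent sets of linear projections, producing h groups of Q, K, and V.
2. Attention is computed for each group, giving h new representations.
3. The h outputs are concatenated into one vector.
4. A final linear projection of the concatenated vector gives the output of multi-head attention.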
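The pure-Python sketch below spells out those steps: project, give each of the h heads its own slice of columns, run scaled dot-product attention per head, concatenate, and project again. It assumes separate weight matrices for Q, K, and V and omits biases. The Attention class later in this section fuses the three projections into one Dense layer of width 3·dim, which amounts to the same computation. All names here are illustrative.

```python
import math
from typing import List

Matrix = List[List[float]]  # rows are tokens


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product a @ b."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def softmax(row: List[float]) -> List[float]:
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Scaled dot-product attention for a single head."""
    scale = 1.0 / math.sqrt(len(k[0]))
    out = []
    for q_row in q:
        weights = softmax([scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k])
        out.append([sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(len(v[0]))])
    return out


def multi_head_attention(x: Matrix, w_q: Matrix, w_k: Matrix, w_v: Matrix,
                         w_o: Matrix, num_heads: int) -> Matrix:
    dim = len(w_q[0])
    if dim % num_heads != 0:
        raise ValueError("dim must be divisible by num_heads")
    head_dim = dim // num_heads
    # 1. project the input into Q, K, V
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    # 2. each head attends over its own slice of head_dim columns
    heads = []
    for h in range(num_heads):
        cols = slice(h * head_dim, (h + 1) * head_dim)
        heads.append(attention([r[cols] for r in q], [r[cols] for r in k], [r[cols] for r in v]))
    # 3. concatenate the head outputs token by token
    concat = [[val for head in heads for val in head[i]] for i in range(len(x))]
    # 4. final linear projection
    return matmul(concat, w_o)


if __name__ == "__main__":
    eye = [[1.0, 0.0], [0.0, 1.0]]
    x = [[1.0, 0.0], [0.0, 1.0]]
    for row in multi_head_attention(x, eye, eye, eye, eye, num_heads=2):
        print([round(val, 4) for val in row])
```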
Experiment
Data loading
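import os
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms

data_path = './dataset/'
# ImageNet mean/std scaled to [0, 255], because the decoded images are uint8
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)

# Data augmentation pipeline
trans_train = [
    # decode, randomly crop, and resize to 224x224
    transforms.RandomCropDecodeResize(size=224,
                                      scale=(0.08, 1.0),
                                      ratio=(0.75, 1.333)),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.Normalize(mean=mean, std=std),
    # HWC -> CHW, the layout expected by nn.Conv2d
    transforms.HWC2CHW()
]
dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)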
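Why are mean and std multiplied by 255? The decode step produces uint8 pixels in [0, 255] before Normalize runs. On that assumption, scaling the usual ImageNet statistics by 255 gives the same result as first dividing the pixels by 255. The pure-Python check below compares the two forms; the function names are hypothetical.

```python
import math
from typing import List

IMAGENET_MEAN = [0.485, 0.456, 0.406]  # per-channel statistics on the [0, 1] scale
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_0_255(pixel: List[float]) -> List[float]:
    """What the pipeline above does: stats scaled by 255, applied to raw 0-255 values."""
    return [(p - m * 255) / (s * 255) for p, m, s in zip(pixel, IMAGENET_MEAN, IMAGENET_STD)]


def normalize_0_1(pixel: List[float]) -> List[float]:
    """Equivalent alternative: rescale the pixel to [0, 1] first, then use the raw stats."""
    return [(p / 255 - m) / s for p, m, s in zip(pixel, IMAGENET_MEAN, IMAGENET_STD)]


if __name__ == "__main__":
    pixel = [124.0, 116.0, 104.0]
    a, b = normalize_0_255(pixel), normalize_0_1(pixel)
    print(all(math.isclose(x, y, abs_tol=1e-12) for x, y in zip(a, b)))  # True
```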
Implementing the attention mechanism
import mindspore as ms
from mindspore import nn, ops


class Attention(nn.Cell):
    # Args: dimension of the input/output vectors, number of attention heads,
    # keep probability for the output, keep probability for the attention matrix
    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        # dimension of each head
        head_dim = dim // num_heads
        # scaling factor 1 / sqrt(head_dim)
        self.scale = ms.Tensor(head_dim ** -0.5)
        # one linear layer projects the input to q, k and v at once
        self.qkv = nn.Dense(dim, dim * 3)
        self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)
        # output linear layer
        self.out = nn.Dense(dim, dim)
        self.out_drop = nn.Dropout(p=1.0-keep_prob)
        # batched matmul: attention matrix times v
        self.attn_matmul_v = ops.BatchMatMul()
        # batched matmul: q times k^T
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax(axis=-1)

    def construct(self, x):
        b, n, c = x.shape
        qkv = self.qkv(x)
        # (b, n, 3*c) -> (b, n, 3, heads, head_dim) -> (3, b, heads, n, head_dim)
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = ops.unstack(qkv, axis=0)
        # scaled dot-product scores: (b, heads, n, n)
        attn = self.q_matmul_k(q, k)
        attn = ops.mul(attn, self.scale)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        out = self.attn_matmul_v(attn, v)
        # merge the heads back: (b, heads, n, head_dim) -> (b, n, c)
        out = ops.transpose(out, (0, 2, 1, 3))
        out = ops.reshape(out, (b, n, c))
        out = self.out(out)
        out = self.out_drop(out)
        return out
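The reshapes and transposes in construct are easier to follow with concrete numbers. The sketch below traces tensor shapes step by step for ViT-B/16: batch 16, 197 tokens, dim 768, 12 heads. It only does shape bookkeeping and does not call MindSpore. The function names are hypothetical.

```python
from typing import List, Tuple

Shape = Tuple[int, ...]


def permute(shape: Shape, perm: Tuple[int, ...]) -> Shape:
    """Shape after a transpose with the given axis permutation."""
    return tuple(shape[i] for i in perm)


def trace_attention_shapes(b: int, n: int, c: int, num_heads: int) -> List[Tuple[str, Shape]]:
    """Mirror each step of Attention.construct and record the resulting shape."""
    if c % num_heads != 0:
        raise ValueError("c must be divisible by num_heads")
    d = c // num_heads
    steps: List[Tuple[str, Shape]] = [("input x", (b, n, c))]
    steps.append(("self.qkv(x)", (b, n, 3 * c)))
    qkv = (b, n, 3, num_heads, d)
    steps.append(("reshape", qkv))
    qkv = permute(qkv, (2, 0, 3, 1, 4))
    steps.append(("transpose (2, 0, 3, 1, 4)", qkv))
    steps.append(("q, k, v after unstack", qkv[1:]))  # unstack drops axis 0
    steps.append(("q @ k^T (scores)", (b, num_heads, n, n)))
    out = (b, num_heads, n, d)
    steps.append(("attn @ v", out))
    out = permute(out, (0, 2, 1, 3))
    steps.append(("transpose (0, 2, 1, 3)", out))
    steps.append(("reshape (merge heads)", (b, n, num_heads * d)))
    return steps


if __name__ == "__main__":
    for name, shape in trace_attention_shapes(16, 197, 768, 12):
        print(f"{name:28s} {shape}")
```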
Building the feed-forward network and the residual connection
from typing import Optional

from mindspore import nn


class FeedForward(nn.Cell):
    # Args: input feature dim, hidden feature dim, output feature dim, activation, keep probability
    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None,
                 activation: nn.Cell = nn.GELU,
                 keep_prob: float = 1.0):
        super(FeedForward, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # fully connected layer 1
        self.dense1 = nn.Dense(in_features, hidden_features)
        self.activation = activation()
        # fully connected layer 2
        self.dense2 = nn.Dense(hidden_features, out_features)
        self.dropout = nn.Dropout(p=1.0-keep_prob)

    def construct(self, x):
        """Feed Forward construct."""
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = self.dropout(x)
        return x


class ResidualCell(nn.Cell):
    """Wraps a cell so that its input is added to its output: y = cell(x) + x."""
    def __init__(self, cell):
        super(ResidualCell, self).__init__()
        self.cell = cell

    def construct(self, x):
        # add the input directly to the output
        return self.cell(x) + x
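FeedForward uses nn.GELU as its activation by default. GELU(x) = x·Φ(x), where Φ is the standard normal CDF, and a tanh-based approximation is also widely used. MindSpore's nn.GELU chooses between the two through its `approximate` argument; check the documentation for your version's default. The pure-Python sketch below implements both forms for comparison. It is illustrative only, not MindSpore's kernel.

```python
import math


def gelu_exact(x: float) -> float:
    """GELU(x) = x * Phi(x), with Phi the standard normal CDF (via erf)."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


if __name__ == "__main__":
    xs = [i / 10 for i in range(-40, 41)]
    max_gap = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in xs)
    print(f"max |exact - tanh| on [-4, 4]: {max_gap:.2e}")
```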
Building the encoder
class TransformerEncoder(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_layers: int,
                 num_heads: int,
                 mlp_dim: int,
                 keep_prob: float = 1.,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,  # accepted but not used in this implementation
                 activation: nn.Cell = nn.GELU,
                 norm: nn.Cell = nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        layers = []
        # build the encoder from the Attention and FeedForward classes defined above
        for _ in range(num_layers):
            # each layer: two normalization layers, one attention layer, one feed-forward layer
            normalization1 = norm((dim,))
            normalization2 = norm((dim,))
            attention = Attention(dim=dim,
                                  num_heads=num_heads,
                                  keep_prob=keep_prob,
                                  attention_keep_prob=attention_keep_prob)
            feedforward = FeedForward(in_features=dim,
                                      hidden_features=mlp_dim,
                                      activation=activation,
                                      keep_prob=keep_prob)
            # pre-norm: LayerNorm -> sublayer, each wrapped in a residual connection
            layers.append(
                nn.SequentialCell([
                    ResidualCell(nn.SequentialCell([normalization1, attention])),
                    ResidualCell(nn.SequentialCell([normalization2, feedforward]))
                ])
            )
        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        return self.layers(x)
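How do we turn a 2-D image into a sequence of 1-D "word vectors"?
ViT uses a convolution whose kernel size and stride both equal the patch size (16). This cuts the input image into non-overlapping 16×16 patches; a 224×224 image yields 14×14 = 196 patches. The same convolution projects each patch, across all channels, to an embed_dim-dimensional vector. The output feature map is then flattened into a sequence of 196 vectors, which closely resembles a stack of word vectors.
class PatchEmbedding(nn.Cell):
    MIN_NUM_PATCHES = 4

    # by default the input image is assumed to be 224x224
    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 input_channels: int = 3):
        super(PatchEmbedding, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        # kernel_size == stride == patch_size: one output position per non-overlapping patch
        self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)

    def construct(self, x):
        """Patch Embedding construct."""
        # (b, 3, 224, 224) -> (b, embed_dim, 14, 14)
        x = self.conv(x)
        b, c, h, w = x.shape
        # flatten the spatial grid and put tokens first: (b, embed_dim, 196) -> (b, 196, embed_dim)
        x = ops.reshape(x, (b, c, h * w))
        x = ops.transpose(x, (0, 2, 1))
        return x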
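Why does a convolution split the image into patches? When kernel_size equals stride, each output position sees exactly one non-overlapping patch. The convolution is therefore the same as flattening every patch and applying one shared linear map. The pure-Python sketch below checks this for a single output channel, whereas nn.Conv2d computes all embed_dim channels at once. The helper names are hypothetical.

```python
from typing import List

Image = List[List[List[float]]]  # (channels, height, width)


def patchify(image: Image, patch_size: int) -> List[List[float]]:
    """Cut a (C, H, W) image into non-overlapping patches flattened in (c, i, j) order,
    listed row by row (the same order as reshape(h * w) on the conv output)."""
    channels, height, width = len(image), len(image[0]), len(image[0][0])
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by patch_size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            vec = []
            for c in range(channels):
                for i in range(top, top + patch_size):
                    vec.extend(image[c][i][left:left + patch_size])
            patches.append(vec)
    return patches


def strided_conv(image: Image, kernel: Image, bias: float, patch_size: int) -> List[float]:
    """One output channel of a conv with kernel_size == stride == patch_size, row-major."""
    channels, height, width = len(image), len(image[0]), len(image[0][0])
    out = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            acc = bias
            for c in range(channels):
                for i in range(patch_size):
                    for j in range(patch_size):
                        acc += image[c][top + i][left + j] * kernel[c][i][j]
            out.append(acc)
    return out


def linear_on_patches(image: Image, kernel: Image, bias: float, patch_size: int) -> List[float]:
    """Same result: flatten each patch and take a dot product with the flattened kernel."""
    flat_kernel = [w for channel in kernel for row in channel for w in row]
    return [bias + sum(p * w for p, w in zip(patch, flat_kernel))
            for patch in patchify(image, patch_size)]


if __name__ == "__main__":
    image = [[[float(4 * r + col) for col in range(4)] for r in range(4)]]  # 1 x 4 x 4
    kernel = [[[1.0, 2.0], [3.0, 4.0]]]
    print(patchify(image, 2))
    print(strided_conv(image, kernel, 0.5, 2) == linear_on_patches(image, kernel, 0.5, 2))  # True
```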
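After the image has been split into patches, two more steps follow:
- class_embedding: a learnable class token is prepended to the sequence of patch vectors (the "word vectors"). Its output is later used for classification.
- pos_embedding: a learnable position embedding is added to the sequence so that each token carries information about its position.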
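In plain Python, the two steps amount to one list prepend and one element-wise addition. The ViT class below does the same with ops.concat and +=, tiling the class token across the batch. The function name and the numbers are illustrative.

```python
from typing import List

Tokens = List[List[float]]


def add_cls_and_position(patch_tokens: Tokens, cls_token: List[float], pos_embedding: Tokens) -> Tokens:
    """Prepend the class token, then add one position vector to every token (element-wise)."""
    tokens = [list(cls_token)] + [list(t) for t in patch_tokens]
    if len(pos_embedding) != len(tokens):
        raise ValueError("pos_embedding needs num_patches + 1 rows")
    return [[a + b for a, b in zip(tok, pos)] for tok, pos in zip(tokens, pos_embedding)]


if __name__ == "__main__":
    patches = [[1.0, 1.0], [2.0, 2.0]]          # 2 patch tokens, dim 2
    cls = [0.0, 0.0]                            # learnable in the real model
    pos = [[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]]  # learnable, num_patches + 1 rows
    print(add_cls_and_position(patches, cls, pos))
```

The position table has num_patches + 1 rows because the class token also gets a position. This matches the shape (1, num_patches + 1, embed_dim) of pos_embedding in the model.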
Building the complete model
from typing import Optional

import mindspore as ms
from mindspore import nn, ops, Parameter
from mindspore.common.initializer import initializer, Normal


def init(init_type, shape, dtype, name, requires_grad):
    """Create a Parameter filled by the given initializer."""
    initial = initializer(init_type, shape, dtype).init_data()
    return Parameter(initial, name=name, requires_grad=requires_grad)


class ViT(nn.Cell):
    def __init__(self,
                 image_size: int = 224,
                 input_channels: int = 3,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 num_layers: int = 12,
                 num_heads: int = 12,
                 mlp_dim: int = 3072,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: Optional[nn.Cell] = nn.LayerNorm,
                 pool: str = 'cls',
                 num_classes: int = 1000) -> None:
        super(ViT, self).__init__()
        # first split the image into patches and embed them
        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              embed_dim=embed_dim,
                                              input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches
        # learnable class token
        self.cls_token = init(init_type=Normal(sigma=1.0),
                              shape=(1, 1, embed_dim),
                              dtype=ms.float32,
                              name='cls',
                              requires_grad=True)
        # learnable position embedding (one extra position for the class token)
        self.pos_embedding = init(init_type=Normal(sigma=1.0),
                                  shape=(1, num_patches + 1, embed_dim),
                                  dtype=ms.float32,
                                  name='pos_embedding',
                                  requires_grad=True)
        self.pool = pool
        self.pos_dropout = nn.Dropout(p=1.0-keep_prob)
        self.norm = norm((embed_dim,))
        self.transformer = TransformerEncoder(dim=embed_dim,
                                              num_layers=num_layers,
                                              num_heads=num_heads,
                                              mlp_dim=mlp_dim,
                                              keep_prob=keep_prob,
                                              attention_keep_prob=attention_keep_prob,
                                              drop_path_keep_prob=drop_path_keep_prob,
                                              activation=activation,
                                              norm=norm)
        self.dropout = nn.Dropout(p=1.0-keep_prob)
        # classification head
        self.dense = nn.Dense(embed_dim, num_classes)

    def construct(self, x):
        """ViT construct."""
        x = self.patch_embedding(x)
        # prepend the class token to every sample in the batch
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
        x = ops.concat((cls_tokens, x), axis=1)
        x += self.pos_embedding
        x = self.pos_dropout(x)
        x = self.transformer(x)
        x = self.norm(x)
        # keep only the output at the class-token position: (b, embed_dim)
        x = x[:, 0]
        if self.training:
            x = self.dropout(x)
        x = self.dense(x)
        return x
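As a sanity check on the architecture, the sketch below counts the learnable parameters implied by the default arguments, with num_classes=1000 as in the training section. It assumes every Dense and Conv2d layer has a bias and every LayerNorm has a scale and a shift, matching the code above. In practice, summing the sizes of network.trainable_params() is the authoritative count. The function names are hypothetical.

```python
def dense_params(in_features: int, out_features: int) -> int:
    """Weights plus bias of a fully connected layer."""
    return in_features * out_features + out_features


def vit_param_count(image_size: int = 224, patch_size: int = 16, channels: int = 3,
                    embed_dim: int = 768, num_layers: int = 12, mlp_dim: int = 3072,
                    num_classes: int = 1000) -> dict:
    num_patches = (image_size // patch_size) ** 2
    per_layer = (
        dense_params(embed_dim, 3 * embed_dim)   # Attention.qkv
        + dense_params(embed_dim, embed_dim)     # Attention.out
        + dense_params(embed_dim, mlp_dim)       # FeedForward.dense1
        + dense_params(mlp_dim, embed_dim)       # FeedForward.dense2
        + 2 * (2 * embed_dim)                    # two LayerNorms, scale + shift each
    )
    counts = {
        "patch_embedding": channels * patch_size * patch_size * embed_dim + embed_dim,
        "cls_token": embed_dim,
        "pos_embedding": (num_patches + 1) * embed_dim,
        "encoder": num_layers * per_layer,
        "final_norm": 2 * embed_dim,
        "head": dense_params(embed_dim, num_classes),
    }
    counts["total"] = sum(counts.values())
    return counts


if __name__ == "__main__":
    for name, value in vit_param_count().items():
        print(f"{name:16s} {value:,}")
```

Almost all of the roughly 86.6 million parameters sit in the 12 encoder layers. The per-layer count is dominated by the feed-forward network.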
Model training
from download import download
import mindspore as ms
from mindspore import nn, ops, train
from mindspore.nn import LossBase
from mindspore.train import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor

# define hyperparameters
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()

# construct model
network = ViT()

# load the pretrained checkpoint
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"
vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

# cosine learning-rate decay
lr = nn.cosine_decay_lr(min_lr=float(0),
                        max_lr=0.00005,
                        total_step=epoch_size * step_size,
                        step_per_epoch=step_size,
                        decay_epoch=10)

# define the optimizer; nn.Adam has no momentum argument, so the value is passed as beta1
network_opt = nn.Adam(network.trainable_params(), learning_rate=lr, beta1=momentum)


# define the loss function: softmax cross-entropy with label smoothing
class CrossEntropySmooth(LossBase):
    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        # target value for the true class and for every other class
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            # integer labels -> smoothed one-hot targets
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss


network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)

# save a checkpoint once per epoch
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)

# initialize the model
# "Ascend + mixed precision" can improve performance
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")

# train the model
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
            dataset_sink_mode=False,)
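Two parts of this training setup are easy to misread. The first is the target vector that CrossEntropySmooth builds with smooth_factor=0.1. The second is the per-epoch cosine schedule. The pure-Python sketch below covers both; the function names are hypothetical. The schedule follows our reading of the formula in MindSpore's nn.cosine_decay_lr documentation: min_lr + 0.5·(max_lr − min_lr)·(1 + cos(π·epoch/decay_epoch)), where epoch = step // step_per_epoch. Treat it as an approximation and check the docs for your version. Clamping the epoch at decay_epoch is our assumption and does not matter for the 10-epoch run above.

```python
import math
from typing import List


def smoothed_targets(label: int, num_classes: int, smooth_factor: float) -> List[float]:
    """1 - eps on the true class and eps / (K - 1) elsewhere, as in CrossEntropySmooth."""
    on_value = 1.0 - smooth_factor
    off_value = smooth_factor / (num_classes - 1)
    return [on_value if i == label else off_value for i in range(num_classes)]


def softmax_cross_entropy(logits: List[float], target: List[float]) -> float:
    """-sum(target * log_softmax(logits)), computed with the log-sum-exp trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return -sum(t * (z - log_sum_exp) for z, t in zip(logits, target))


def cosine_decay(min_lr: float, max_lr: float, total_step: int,
                 step_per_epoch: int, decay_epoch: int) -> List[float]:
    """Per-step learning rate: constant within an epoch, half a cosine over decay_epoch epochs."""
    lrs = []
    for step in range(total_step):
        epoch = min(step // step_per_epoch, decay_epoch)  # clamp is an assumption
        lrs.append(min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * epoch / decay_epoch)))
    return lrs


if __name__ == "__main__":
    target = smoothed_targets(label=3, num_classes=1000, smooth_factor=0.1)
    print(target[3], target[0], sum(target))
    lrs = cosine_decay(0.0, 5e-5, total_step=100, step_per_epoch=10, decay_epoch=10)
    print(lrs[0], lrs[50], lrs[-1])
```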
Model evaluation
dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)
trans_val = [
    transforms.Decode(),
    # resize the shorter side to 256, then center-crop 224x224
    transforms.Resize(224 + 32),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]
dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)

# construct model
network = ViT()

# load the checkpoint: vit_path is the downloaded pretrained checkpoint;
# load a file saved under ./ViT instead to evaluate the fine-tuned weights
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)
network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)

# evaluate with Top-1 and Top-5 accuracy
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),
                'Top_5_Accuracy': train.Top5CategoricalAccuracy()}

# no optimizer is needed for evaluation
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, metrics=eval_metrics, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, metrics=eval_metrics, amp_level="O0")

# evaluate the model
result = model.eval(dataset_val)
print(result)
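Top_1_Accuracy and Top_5_Accuracy both measure how often the true label appears among the model's highest-scoring classes. The pure-Python sketch below computes top-k accuracy for a batch of logits. It illustrates the definition only and is not MindSpore's implementation; tie-breaking in particular may differ. The function name is hypothetical.

```python
from typing import Sequence


def top_k_accuracy(logits_batch: Sequence[Sequence[float]], labels: Sequence[int], k: int) -> float:
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    if len(logits_batch) != len(labels) or not labels:
        raise ValueError("need one label per row of logits, and at least one row")
    correct = 0
    for logits, label in zip(logits_batch, labels):
        top_k = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
        if label in top_k:
            correct += 1
    return correct / len(labels)


if __name__ == "__main__":
    logits = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]]
    labels = [2, 0]
    print(top_k_accuracy(logits, labels, k=1), top_k_accuracy(logits, labels, k=2))
```

Top-5 accuracy can never be lower than top-1 accuracy, because the top-5 set always contains the top-1 prediction.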
Summary
In this chapter we built, trained, and evaluated a ViT model on the ImageNet dataset, and learned how the ViT network is put together.