Table of Contents
- I. Introduction
- Task: Few-shot Classification
- Experiments
- 1. simple
- 2. medium
- 3. strong
- 4. boss
- III. Code
- Preparation for building the model
I. Introduction
The task is few-shot classification on the Omniglot dataset. The goal is to use meta-learning to find a good parameter initialization.
Task: Few-shot Classification
The Omniglot dataset
Omniglot dataset: background set with 30 alphabets, evaluation set with 20 alphabets
Problem setup: 5-way 1-shot classification
Training MAML on Omniglot classification task.
Training / validation set: 30 alphabets
- multiple characters in one alphabet
- 20 images for one character
Testing set: 640 support and query pairs
- 5 support images
- 5 query images
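To make the 5-way 1-shot setup concrete, the sketch below splits one task into its support and query parts. The 28x28 single-channel image size and the "support images first, query images second" layout are assumptions made for illustration; the slicing itself matches what the solvers later in this post do.

```python
import torch

n_way, k_shot, q_query = 5, 1, 1

# Hypothetical task tensor: n_way * (k_shot + q_query) grayscale images.
# Assumed layout: all support images first, then all query images.
task = torch.zeros(n_way * (k_shot + q_query), 1, 28, 28)

support_set = task[: n_way * k_shot]   # 5 images, one per class
query_set = task[n_way * k_shot :]     # 5 images, one per class

print(support_set.shape, query_set.shape)
```

With 5 classes and 1 shot, each test task therefore carries 5 support images and 5 query images, which is the pairing listed above.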
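Experiments
1. simple
A simple transfer-learning baseline.
Training: ordinary classification training on 5 randomly selected tasks.
Validation and testing: fine-tune on the 5 support images, then run inference on the query images.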
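2. medium
Complete the TODO blocks for the inner and outer meta-learning loops using FO-MAML. Set solver = 'meta' and set the number of epochs to 120. FO-MAML is a simplified version of MAML that saves training time: it treats the inner-loop gradients as constants, so the second-order terms (the gradient of the inner-loop gradient) are dropped from the meta-update.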
# TODO: Finish the inner loop update rule
# First-order variant: without create_graph, the inner gradients are constants
grads = torch.autograd.grad(loss, fast_weights.values())
fast_weights = OrderedDict(
    (name, param - inner_lr * grad)
    for ((name, param), grad) in zip(fast_weights.items(), grads)
)
# raise NotImplementedError
# TODO: Finish the outer loop update
meta_batch_loss.backward()
optimizer.step()
# raise NotImplementedError
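The approximation made by FO-MAML can be written out for a single inner step. The notation here is introduced for illustration and is not from the original assignment: $\alpha$ stands for inner_lr, $L_S$ for the support-set loss, and $L_Q$ for the query-set loss.

```latex
\begin{aligned}
\theta' &= \theta - \alpha \,\nabla_\theta L_S(\theta)
  && \text{(inner-loop update)} \\
\nabla_\theta L_Q(\theta') &= \bigl(I - \alpha \,\nabla_\theta^2 L_S(\theta)\bigr)\,\nabla_{\theta'} L_Q(\theta')
  && \text{(MAML meta-gradient)} \\
\nabla_\theta L_Q(\theta') &\approx \nabla_{\theta'} L_Q(\theta')
  && \text{(FO-MAML: Hessian term dropped)}
\end{aligned}
```

The Hessian $\nabla_\theta^2 L_S(\theta)$ is the "gradient of the inner-loop gradient"; skipping it is what saves time and memory.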
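3. strong
Use full MAML, which computes higher-order gradients. Passing create_graph=True keeps the inner-loop gradients in the computation graph, so the meta-update can differentiate through them (the gradient of the inner-loop gradient).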
# TODO: Finish the inner loop update rule
# create_graph=True keeps the inner gradients differentiable (second-order MAML)
grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
fast_weights = OrderedDict(
    (name, param - inner_lr * grad)
    for ((name, param), grad) in zip(fast_weights.items(), grads)
)
# raise NotImplementedError
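The effect of create_graph can be checked on a one-parameter toy problem. This is a minimal sketch, not part of the assignment code: the quadratic losses, the starting value 1.0, the step size 0.1, and the helper name meta_gradient are all chosen for illustration.

```python
import torch

inner_lr = 0.1  # illustrative step size, not the value used in the solvers


def meta_gradient(create_graph):
    """Return d(query_loss)/d(theta) after one inner step on a toy problem."""
    theta = torch.tensor(1.0, requires_grad=True)
    support_loss = theta ** 2
    # With create_graph=False the returned gradient is a constant tensor.
    (grad,) = torch.autograd.grad(support_loss, theta, create_graph=create_graph)
    fast_theta = theta - inner_lr * grad
    query_loss = fast_theta ** 2
    query_loss.backward()
    return theta.grad.item()


full = meta_gradient(create_graph=True)          # differentiates through grad
first_order = meta_gradient(create_graph=False)  # treats grad as a constant
print(full, first_order)
```

By hand: the inner step gives fast_theta = 0.8 * theta, so the full meta-gradient is 2 * 0.8 * 0.8 = 1.28, while the first-order version stops at 2 * 0.8 = 1.6. The gap between the two numbers is exactly the second-order term.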
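4. boss
Task augmentation (for meta-learning): what is a reasonable way to create new tasks?
Task augmentation is used to increase the variety of training tasks: with 40% probability a task is augmented by rotating its images by 90 or 270 degrees.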
# Modification inside the MetaSolver function
for meta_batch in x:
    # Get data
    if torch.rand(1).item() > 0.6:  # augment with probability 0.4
        times = 1 if torch.rand(1).item() > 0.5 else 3  # 90 or 270 degrees
        meta_batch = torch.rot90(meta_batch, times, [-1, -2])
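The rotation step can be tried in isolation. The sketch below wraps it in a hypothetical helper, maybe_rotate (the name and the parameter p are not from the original code), and uses a tiny 4x4 "image" tensor so the behavior is easy to inspect.

```python
import torch


def maybe_rotate(meta_batch, p=0.4):
    """With probability p, rotate every image of the task by 90 or 270 degrees."""
    if torch.rand(1).item() < p:
        times = 1 if torch.rand(1).item() > 0.5 else 3
        # Rotate in the plane of the last two dims (height and width).
        meta_batch = torch.rot90(meta_batch, times, [-1, -2])
    return meta_batch


# Toy task: 2 images, 1 channel, 4x4 pixels with distinct values.
task = torch.arange(2 * 1 * 4 * 4, dtype=torch.float32).reshape(2, 1, 4, 4)

rot1 = torch.rot90(task, 1, [-1, -2])          # one quarter turn
restored = torch.rot90(rot1, 3, [-1, -2])      # three more quarter turns
augmented = maybe_rotate(task)
print(rot1.shape, torch.equal(restored, task))
```

Because the whole task tensor is rotated at once, the support and query images of a task are always transformed identically, so the rotated task is still internally consistent and can be treated as a new character class set.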
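III. Code
Preparation for building the model
Since the task is image classification, we need a CNN-based model. To implement the MAML algorithm, however, some of the code in "nn.Module" has to be adapted. In line 10 of the MAML algorithm, the gradient is taken with respect to θ, the original model parameters (outer loop), rather than the θ of the inner loop. We therefore compute the output logits of the input images with functional_forward, which receives the parameters as an explicit argument, instead of the forward method of nn.Module. These functions are defined below.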
def functional_forward(self, x, params):
    # Forward pass that uses the given parameter dict instead of self.parameters()
    for block in [1, 2, 3, 4]:
        x = ConvBlockFunction(
            x,
            params[f"conv{block}.0.weight"],
            params[f"conv{block}.0.bias"],
            params.get(f"conv{block}.1.weight"),
            params.get(f"conv{block}.1.bias"),
        )
    x = x.view(x.shape[0], -1)
    x = F.linear(x, params["logits.weight"], params["logits.bias"])
    return x
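ConvBlockFunction is called above but its definition is not shown in this section. The sketch below shows one plausible shape for such a functional block, assuming a 3x3 convolution with padding 1, batch normalization using batch statistics, ReLU, and 2x2 max pooling. The name conv_block_sketch, the 64 output channels, and the 28x28 input size are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F


def conv_block_sketch(x, w, b, w_bn, b_bn):
    """Functional conv block: every parameter is passed in explicitly."""
    x = F.conv2d(x, w, b, padding=1)
    # No running statistics: normalize with the statistics of the current batch.
    x = F.batch_norm(
        x, running_mean=None, running_var=None,
        weight=w_bn, bias=b_bn, training=True,
    )
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2)
    return x


x = torch.randn(4, 1, 28, 28)                     # assumed input size
w = torch.randn(64, 1, 3, 3, requires_grad=True)  # assumed 64 filters
b = torch.zeros(64, requires_grad=True)
w_bn = torch.ones(64, requires_grad=True)
b_bn = torch.zeros(64, requires_grad=True)

out = conv_block_sketch(x, w, b, w_bn, b_bn)
print(out.shape)
```

Passing the weights as arguments is what lets the inner loop swap in fast_weights while gradients still flow back to the original parameters.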
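Creating labels for 5-way 2-shot
import torch

def create_label(n_way, k_shot):
    # Class i is repeated k_shot times: [0, ..., 0, 1, ..., 1, ...]
    return torch.arange(n_way).repeat_interleave(k_shot).long()

# Try to create labels for 5-way 2-shot setting
create_label(5, 2)  # tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])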
Computing accuracy
import numpy as np

def calculate_accuracy(logits, labels):
    """utility function for accuracy calculation"""
    # Fraction of samples whose arg-max prediction equals the label
    acc = np.asarray(
        [(torch.argmax(logits, -1).cpu().numpy() == labels.cpu().numpy())]
    ).mean()
    return acc
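A quick sanity check of the accuracy computation on hand-made values. The helper accuracy_sketch below is a torch-only variant written for this example (it is not the function from the assignment), and the logits and labels are made up.

```python
import torch


def accuracy_sketch(logits, labels):
    """Fraction of rows whose arg-max class equals the label."""
    preds = torch.argmax(logits, dim=-1)
    return (preds == labels).float().mean().item()


# Three samples, two classes: predictions are [0, 1, 0].
logits = torch.tensor([[2.0, 1.0], [0.0, 3.0], [1.0, 0.0]])
labels = torch.tensor([0, 1, 1])

acc = accuracy_sketch(logits, labels)
print(acc)
```

Two of the three predictions match the labels, so the result is 2/3.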
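The base solver first selects five tasks from the training set and performs ordinary classification training on them. At inference time, the model is fine-tuned on the support-set images for inner_train_step steps and then evaluated on the query-set images. To stay consistent with the meta-learning solver, the base solver has exactly the same input and output format.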
def BaseSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False,
):
    criterion, task_loss, task_acc = loss_fn, [], []
    labels = []
    for meta_batch in x:
        # Get data
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]
        if train:
            """ training loop """
            # Use the support set to calculate loss
            labels = create_label(n_way, k_shot).to(device)
            logits = model.forward(support_set)
            loss = criterion(logits, labels)
            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, labels))
        else:
            """ validation / testing loop """
            # First update model with support set images for `inner_train_step` steps
            fast_weights = OrderedDict(model.named_parameters())
            for inner_step in range(inner_train_step):
                # Simply training
                train_label = create_label(n_way, k_shot).to(device)
                logits = model.functional_forward(support_set, fast_weights)
                loss = criterion(logits, train_label)
                grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
                # Perform SGD
                fast_weights = OrderedDict(
                    (name, param - inner_lr * grad)
                    for ((name, param), grad) in zip(fast_weights.items(), grads)
                )
            if not return_labels:
                """ validation """
                val_label = create_label(n_way, q_query).to(device)
                logits = model.functional_forward(query_set, fast_weights)
                loss = criterion(logits, val_label)
                task_loss.append(loss)
                task_acc.append(calculate_accuracy(logits, val_label))
            else:
                """ testing """
                logits = model.functional_forward(query_set, fast_weights)
                labels.extend(torch.argmax(logits, -1).cpu().numpy())
    if return_labels:
        return labels
    batch_loss = torch.stack(task_loss).mean()
    task_acc = np.mean(task_acc)
    if train:
        # Update model
        model.train()
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
    return batch_loss, task_acc
Meta-learning
def MetaSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False
):
    criterion, task_loss, task_acc = loss_fn, [], []
    labels = []
    for meta_batch in x:
        # Get data
        if torch.rand(1).item() > 0.6:
            times = 1 if torch.rand(1).item() > 0.5 else 3
            # Rotate by times * 90 degrees in the plane of the last two (image) dims
            meta_batch = torch.rot90(meta_batch, times, [-1, -2])
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]
        # Copy the params for inner loop
        fast_weights = OrderedDict(model.named_parameters())
        ### ---------- INNER TRAIN LOOP ---------- ###
        for inner_step in range(inner_train_step):
            # Simply training
            train_label = create_label(n_way, k_shot).to(device)
            logits = model.functional_forward(support_set, fast_weights)
            loss = criterion(logits, train_label)
            # Inner gradients update! vvvvvvvvvvvvvvvvvvvv #
            """ Inner Loop Update """
            # TODO: Finish the inner loop update rule
            grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), grads)
            )
            # raise NotImplementedError
            # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ #
        ### ---------- INNER VALID LOOP ---------- ###
        if not return_labels:
            """ training / validation """
            val_label = create_label(n_way, q_query).to(device)
            # Collect gradients for outer loop
            logits = model.functional_forward(query_set, fast_weights)
            loss = criterion(logits, val_label)
            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, val_label))
        else:
            """ testing """
            logits = model.functional_forward(query_set, fast_weights)
            labels.extend(torch.argmax(logits, -1).cpu().numpy())
    if return_labels:
        return labels
    # Update outer loop
    model.train()
    optimizer.zero_grad()
    meta_batch_loss = torch.stack(task_loss).mean()
    if train:
        """ Outer Loop Update """
        # TODO: Finish the outer loop update
        meta_batch_loss.backward()
        optimizer.step()
        # raise NotImplementedError
    task_acc = np.mean(task_acc)
    return meta_batch_loss, task_acc
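For readers who want to run the inner/outer loop structure in isolation, here is a minimal sketch of one meta-update on random data. It is not the assignment model: a single nn.Linear layer stands in for the CNN, the 8-dimensional random features stand in for images, and the names functional_forward_sketch and make_labels are hypothetical helpers defined only for this example.

```python
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n_way, k_shot, q_query, feat_dim, inner_lr = 5, 1, 1, 8, 0.4

model = nn.Linear(feat_dim, n_way)  # stand-in for the CNN classifier
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def functional_forward_sketch(x, params):
    # Same idea as functional_forward: use the parameters that are passed in.
    return F.linear(x, params["weight"], params["bias"])


def make_labels(n_way, n_per_class):
    return torch.arange(n_way).repeat_interleave(n_per_class)


# Meta-batch of 4 random tasks; support samples first, then query samples.
meta_batch = torch.randn(4, n_way * (k_shot + q_query), feat_dim)
weight_before = model.weight.detach().clone()

task_losses = []
for task in meta_batch:
    support, query = task[: n_way * k_shot], task[n_way * k_shot :]
    fast_weights = OrderedDict(model.named_parameters())
    # Inner loop: one differentiable SGD step on the support set.
    loss = F.cross_entropy(
        functional_forward_sketch(support, fast_weights), make_labels(n_way, k_shot)
    )
    grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
    fast_weights = OrderedDict(
        (name, param - inner_lr * grad)
        for (name, param), grad in zip(fast_weights.items(), grads)
    )
    # Query loss of the adapted weights feeds the outer loop.
    task_losses.append(
        F.cross_entropy(
            functional_forward_sketch(query, fast_weights), make_labels(n_way, q_query)
        )
    )

# Outer loop: one optimizer step on the mean query loss.
meta_loss = torch.stack(task_losses).mean()
optimizer.zero_grad()
meta_loss.backward()
optimizer.step()
print(meta_loss.item())
```

The optimizer only ever updates the original model parameters; fast_weights exist just long enough to produce the query loss for each task.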