机器学习HW15元学习

news2025/1/15 22:39:21

文章目录

  • 一、简介
    • Task: Few-shot Classification
  • 实验
    • 1、simple
    • 2、medium
    • 3、strong
    • 4、boss
  • 三、代码
      • 模型构建准备工作


一、简介

任务对象是Omniglot数据集上的few-shot classification任务,内容是利用元学习找到好的初始化参数。

Task: Few-shot Classification

The Omniglot dataset
在这里插入图片描述
Omniglot数据集-背景集: 30个字母 -评估集: 20个字母
问题设置: 5-way 1-shot classification
在这里插入图片描述
Training MAML on Omniglot classification task.
在这里插入图片描述
Training / validation set:30 alphabets

  • multiple characters in one alphabet
  • 20 images for one character
    在这里插入图片描述
    Testing set:
    640 support and query pairs
  • 5 support images
  • 5 query images
    在这里插入图片描述

实验

1、simple

简单的迁移学习模型
训练:对随机选择的5个任务进行正常的分类训练
验证和测试:对这5个支持图像进行微调,并对查询图像进行推理
在这里插入图片描述

在这里插入图片描述

2、medium

完成元学习内部和外部循环的TODO块,使用FO-MAML。设置solver = ‘meta’,epoch调节为120。FOMAML是MAML的简化版本,可节省训练时间,它忽略了内循环梯度对结果的影响。

# TODO: Finish the inner loop update rule
grads = torch.autograd.grad(loss, fast_weights.values())
fast_weights = OrderedDict((name, param - inner_lr*grad)
      for ((name, param), grad) in zip(fast_weights.items(), grads)
      )    
#raise NotImplementedError训练过程中需要设置该函数为损失函数


# TODO: Finish the outer loop update
meta_batch_loss.backward()
optimizer.step()
#raise NotimplementedError

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

3、strong

使用MAML,可以计算更高阶的梯度,MAML就能用到内循环梯度的梯度 。

# TODO: Finish the inner loop update rule
grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
fast_weights = OrderedDict((name, param - inner_lr*grad)      
    for ((name, param), grad) in zip(fast_weights.items(), grads)
    )    
#raise NotImplementedError训练过程中需要设置该函数为损失函数

在这里插入图片描述
在这里插入图片描述

4、boss

任务增强(通过元学习)-什么是合理的方法来创建新任务?
使用了task augmentation的方法来增加训练任务的变化性,有40%的可能性做augmentation,旋转90度或270度。
在这里插入图片描述

#MetaSolver函数中修改
for meta_batch in x:
    # Get data
    if torch.rand(1).item() > 0.6:
        times = 1 if torch.rand(1).item() > 0.5 else 3
        meta_batch = torch.rot90(meta_batch, times, [-1, -2])

在这里插入图片描述
在这里插入图片描述

三、代码

模型构建准备工作

由于我们的任务是图像分类,我们需要建立一个基于CNN的模型。但是,要实现MAML算法,我们需要调整“nn.Module”中的一些代码。在第10行,我们采用的梯度是代表原始模型参数(外环)的θ,而不是内环中的θ,因此我们需要使用functional_forward来计算输入图像的输出逻辑,而不是在nn.Module中使用forward。下面定义了这些功能。

def functional_forward(self, x, params):
        for block in [1, 2, 3, 4]:
            x = ConvBlockFunction(
                x,
                params[f"conv{block}.0.weight"],
                params[f"conv{block}.0.bias"],
                params.get(f"conv{block}.1.weight"),
                params.get(f"conv{block}.1.bias"),
            )
        x = x.view(x.shape[0], -1)
        x = F.linear(x, params["logits.weight"], params["logits.bias"])
        return x

创建labels for 5-way 2-shot

def create_label(n_way, k_shot):
    return torch.arange(n_way).repeat_interleave(k_shot).long()


# Try to create labels for 5-way 2-shot setting
create_label(5, 2)

计算精度

def calculate_accuracy(logits, labels):
    """utility function for accuracy calculation"""
    acc = np.asarray(
        [(torch.argmax(logits, -1).cpu().numpy() == labels.cpu().numpy())]
    ).mean()
    return acc

求解器首先从训练集中选择五个任务,然后对选择的五个任务进行正常的分类训练。在推理中,模型在支持集图像上对inner_train_step步骤进行微调,然后在查询集图像上进行推理。为了与元学习解算器保持一致,基本解算器具有与元学习解算器完全相同的输入和输出格式。

def BaseSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False,
):
    criterion, task_loss, task_acc = loss_fn, [], []
    labels = []

    for meta_batch in x:
        # Get data
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]

        if train:
            """ training loop """
            # Use the support set to calculate loss
            labels = create_label(n_way, k_shot).to(device)
            logits = model.forward(support_set)
            loss = criterion(logits, labels)

            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, labels))
        else:
            """ validation / testing loop """
            # First update model with support set images for `inner_train_step` steps
            fast_weights = OrderedDict(model.named_parameters())


            for inner_step in range(inner_train_step):
                # Simply training
                train_label = create_label(n_way, k_shot).to(device)
                logits = model.functional_forward(support_set, fast_weights)
                loss = criterion(logits, train_label)

                grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
                # Perform SGD
                fast_weights = OrderedDict(
                    (name, param - inner_lr * grad)
                    for ((name, param), grad) in zip(fast_weights.items(), grads)
                )

            if not return_labels:
                """ validation """
                val_label = create_label(n_way, q_query).to(device)

                logits = model.functional_forward(query_set, fast_weights)
                loss = criterion(logits, val_label)
                task_loss.append(loss)
                task_acc.append(calculate_accuracy(logits, val_label))
            else:
                """ testing """
                logits = model.functional_forward(query_set, fast_weights)
                labels.extend(torch.argmax(logits, -1).cpu().numpy())

    if return_labels:
        return labels

    batch_loss = torch.stack(task_loss).mean()
    task_acc = np.mean(task_acc)

    if train:
        # Update model
        model.train()
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    return batch_loss, task_acc

元学习

def MetaSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False
):
    criterion, task_loss, task_acc = loss_fn, [], []
    labels = []

    for meta_batch in x:
        # Get data
        if torch.rand(1).item() > 0.6:
            times = 1 if torch.rand(1).item() > 0.5 else 3
            meta_batch = torch.rot90(meta_batch, times, [-1, -2])#  B = rot90(A,k) 将数组 A 按逆时针方向旋转 k*90 度
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]

        # Copy the params for inner loop
        fast_weights = OrderedDict(model.named_parameters())

        ### ---------- INNER TRAIN LOOP ---------- ###
        for inner_step in range(inner_train_step):
            # Simply training
            train_label = create_label(n_way, k_shot).to(device)
            logits = model.functional_forward(support_set, fast_weights)
            loss = criterion(logits, train_label)
            # Inner gradients update! vvvvvvvvvvvvvvvvvvvv #
            """ Inner Loop Update """
            # TODO: Finish the inner loop update rule
            grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
            fast_weights = OrderedDict((name, param - inner_lr*grad)
                                        for ((name, param), grad) in zip(fast_weights.items(), grads)
                                        )
            
            #raise NotImplementedError
            # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ #

        ### ---------- INNER VALID LOOP ---------- ###
        if not return_labels:
            """ training / validation """
            val_label = create_label(n_way, q_query).to(device)

            # Collect gradients for outer loop
            logits = model.functional_forward(query_set, fast_weights)
            loss = criterion(logits, val_label)
            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, val_label))
        else:
            """ testing """
            logits = model.functional_forward(query_set, fast_weights)
            labels.extend(torch.argmax(logits, -1).cpu().numpy())

    if return_labels:
        return labels

    # Update outer loop
    model.train()
    optimizer.zero_grad()

    meta_batch_loss = torch.stack(task_loss).mean()
    if train:
        """ Outer Loop Update """
        # TODO: Finish the outer loop update
        meta_batch_loss.backward()
        optimizer.step()
        #raise NotimplementedError

    task_acc = np.mean(task_acc)
    return meta_batch_loss, task_acc

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/156098.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

在VSCode中使用Compaq Visual Fortran编译运行Frotran程序

本片文章主要是为了使用VSCode编译运行带QuickWin的老版本Fortran代码。 一、准备工作 安装VSCode和Compaq Visual Fortran6.6。 二、配置Fortran工程 用VSCode打开保存有Frotran代码的文件夹 建立.vscode文件夹,建立launch.json和task.json文件,分…

二、TCP/IP---Ethernet和IP协议

TCP/ip协议栈 OSI模型TCP/IP协议栈应用层,表示层,会话层应用层传输层主机到主机层(传输层)网络层网络层数据链路层,物理层网络接入层 Ethernet协议 以太网,实现链路层的数据传输和地址封装(MA…

马蹄集 三角形坐标

三角形坐标 难度&#xff1a;青铜 ○时间限制&#xff1a;1秒 巴占用内存&#xff1a;64M 输入三角形三个顶点A,B,C的坐标(x,y),根据公式计算并输出三 角形面积。 S1/2*X1y2X2y33y1-X1y3-X2y1-x3y2 #include <bits/stdc.h> using namespace std; int main(){double x[4],…

Win10应用商店无法加载错误0x80072F7D怎么办?

Win10应用商店无法加载错误0x80072F7D怎么办&#xff1f;有用户开启电脑的Win10软件商店想要获取软件的时候&#xff0c;发现软件页面无法进行正常的加载&#xff0c;里面的内容显示为错误代码0x80072F7D。那么这个情况怎么去进行解决呢&#xff1f;一起来看看详细的解决方法分…

PMP证书到期了,有必要续吗?

我觉得续证是有需要的&#xff0c;毕竟证书有用的地方很多。 下面我们将从两方面分享&#xff1a; 1. PMP 证书在国内的含金量怎么样&#xff1f; 2. HR 如何看待 PMP 证书&#xff1f; 说到 PMP 证书的含金量&#xff0c;相信这个问题是所有人都关心的。对于如何来评判 PMP…

达芬奇密码题解

题目信息达芬奇隐藏在蒙娜丽莎中的数字列:1 233 3 2584 1346269 144 5 196418 21 1597 610 377 10946 89 514229 987 8 55 6765 2178309 121393 317811 46368 4181 1 832040 2 28657 75025 34 13 17711 记录在达芬奇窗台口的神秘数字串:36968853882116725547342176952286这道题…

Vue3——第十章(异步组件:defineAsyncComponent)

一、defineAsyncComponent基本使用 在大型项目中&#xff0c;我们可能需要拆分应用为更小的块&#xff0c;并仅在需要时再从服务器加载相关组件。Vue 提供了 defineAsyncComponent 方法来实现此功能&#xff1a; 如你所见&#xff0c;defineAsyncComponent 方法接收一个返回 P…

文档流code案例小汇【处理高度塌陷】

clear mdn文档 clear只是清除浮动&#xff0c;不能缓解【高度塌陷】 https://developer.mozilla.org/zh-CN/docs/Web/CSS/clear 值 none 元素不会被向下移动以清除浮动。left 元素被向下移动以清除左浮动。right 元素被向下移动以清除右浮动。both 元素被向下移动以清除左右浮动…

LeetCode刷题模版:61 - 70

目录 简介61. 旋转链表62. 不同路径63. 不同路径 II64. 最小路径和65. 有效数字【未理解】66. 加一67. 二进制求和68. 文本左右对齐【未实现】69. x 的平方根70. 爬楼梯结语简介 Hello! 非常感谢您阅读海轰的文章,倘若文中有错误的地方,欢迎您指出~ ଘ(੭ˊᵕˋ)੭ 昵称:…

1月第1周榜单丨B站UP主排行榜(飞瓜数据B站)发布!

飞瓜轻数发布2023年1月2日-1月8日飞瓜数据UP主排行榜&#xff08;B站平台&#xff09;&#xff0c;通过充电数、涨粉数、成长指数三个维度来体现UP主账号成长的情况&#xff0c;为用户提供B站号综合价值的数据参考&#xff0c;根据UP主成长情况用户能够快速找到运营能力强的B站…

Web(四)

基本对象&#xff1a;1. Function&#xff1a;函数(方法)对象创建&#xff1a;1. var fun new Function(形式参数列表,方法体); //忘掉吧2. function 方法名称(形式参数列表){方法体}3. var 方法名 function(形式参数列表){方法体}既是方法也是对象&#xff0c;不传或者参数不…

PyG Temporal搭建STGCN实现多变量输入多变量输出时间序列预测

目录I. 前言II. STGCNIII. PyG TemporalIV. 模型训练/测试V. 代码I. 前言 前面已经写过不少时间序列预测的文章&#xff1a; 深入理解PyTorch中LSTM的输入和输出&#xff08;从input输入到Linear输出&#xff09;PyTorch搭建LSTM实现时间序列预测&#xff08;负荷预测&#x…

C++——类和对象1

目录 1. 类和对象认识 2. 类的引入 3. 类的定义 4. 类的访问限定符及封装 4.1 访问限定符 4.2 封装 5. 类的作用域 6. 类的实例化 7. 类对象模型 7.1 如何计算类对象的大小 7.2 类对象的存储方式猜测 7.3 结构体内存对齐规则 8. this指针 8.1 this指针的…

cv-cuda (cvcuda、nvcv)教程——Python安装

由于当前版本安装后&#xff0c;大家反应import nvcv cvcuda 失败&#xff0c;看官方文档&#xff0c;当前还不是很规范&#xff0c;特此记录当前版本的安装方法。 官方安装文档&#xff1a;Installation — CV-CUDA Alpha documentation 方法一、如果你有权限推荐deb安装方式…

机器学习第15章-规则学习

机器学习第15章-规则学习 以下列出我觉得重要&#xff0c;在编码的思路中可以参考的地方 冲突消融 当一条规则的判断出现不同的结果时&#xff0c;解决冲突的方法 1.投票法 2.排序法 3.无规则法 序贯覆盖 生成规则过程中去除当前规则所能覆盖的数据 生成方式 自顶向下…

双软认证的好处,赶紧来看看吧

1、“双软件”认可对企业有什么好处&#xff1f; 对于认定的软件企业&#xff0c;从盈利年度起&#xff0c;第一年和第二年免征企业所得税&#xff0c;第三年至第五年减半征收企业所得税&#xff0c;即两免三减。对认定软件产品的企业&#xff0c;对实际增值税负担超过3%的部分…

【ONE·C++ || vector (二)】

总言 主要讲述vector的模拟实现。 文章目录总言1、基本框架搭建&#xff1a;成员变量2、对构造函数、析构函数3、增删查改、空间扩容3.1、vector::push_back、vector::pop_back3.2、vector::reserve、vector::capacity、vector::size3.3、operator[ ]3.4、遍历&#xff1a;迭代…

记录robosense RS-LIDAR-16使用过程1

拿到设备&#xff0c;首先对照型号去官网下载相关资料&#xff08;用户手册/软件/SDK&#xff09;,需要填写资料https://www.robosense.ai/resources-27工业相机通常也有出厂SDK文件&#xff0c;之前有使用知微传感的D130相机&#xff0c;也是先安装SDK、看手册然后使用。大型厂…

【Java集合】Map接口常用方法及实现子类

文章目录01 Map 接口实现类的特点02 Map 接口和常用方法03 Map 接口遍历方法04 HashMap 用例 小结05 HashMap 底层&扩容机制06 Hashtable07 PropertiesMap为双列集合&#xff0c;Set集合的底层也是Map&#xff0c;只不过有一列是常量所占&#xff0c;只使用到了一列。 01 …

国科大《高级人工智能》沈老师部分——行为主义笔记

国科大《高级人工智能》沈老师部分——行为主义笔记 沈华伟老师yyds&#xff0c;每次上他的课都有一种深入浅出的感觉&#xff0c;他能够把很难的东西讲的很简单&#xff0c;听完就是醍醐灌顶&#xff0c;理解起来特别清晰今年考试题目这部分跟往年基本一样&#xff0c;沈老师画…