李宏毅2023机器学习HW15-Few-shot Classification

news2024/9/24 23:22:07

文章目录

  • Link
  • Task: Few-shot Classification
  • Baseline
    • Simple—transfer learning
    • Medium — FO-MAML
    • Strong — MAML

Link

Kaggle

Task: Few-shot Classification

The Omniglot dataset

  • background set: 30 alphabets
  • evaluation set: 20 alphabets
  • Problem setup: 5-way 1-shot classification

Omniglot数据集

  • 背景集:30个字母
  • 评估集:20个字母
  • 问题设置:5-way 1-shot分类
    Definition of support set and query set

Baseline

Simple—transfer learning

直接把sample code运行即可

  • traing:
    对随机选择的5个任务进行正常分类训练验证/测试
  • validation / testing:
    对五个 Support Images 进行微调,并对Query Images进行推理

Slover首先从训练集中选择5个任务,然后对选择的5个任务进行正常分类训练。在推理中,模型在支持集support set图像上微调inner_train_step步骤,然后在查询集Query Set图像上进行推理。
为了与元学习Slover保持一致,基本Slover具有与元学习Slover完全相同的输入输出格式

def BaseSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False,
):
    criterion, task_loss, task_acc = loss_fn, [], []
    labels = []

    for meta_batch in x:
        # Get data
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]

        if train:
            """ training loop """
            # Use the support set to calculate loss
            labels = create_label(n_way, k_shot).to(device)
            logits = model.forward(support_set)
            loss = criterion(logits, labels)

            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, labels))
        else:
            """ validation / testing loop """
            # First update model with support set images for `inner_train_step` steps
            fast_weights = OrderedDict(model.named_parameters())


            for inner_step in range(inner_train_step):
                # Simply training
                train_label = create_label(n_way, k_shot).to(device)
                logits = model.functional_forward(support_set, fast_weights)
                loss = criterion(logits, train_label)

                grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
                # Perform SGD
                fast_weights = OrderedDict(
                    (name, param - inner_lr * grad)
                    for ((name, param), grad) in zip(fast_weights.items(), grads)
                )

            if not return_labels:
                """ validation """
                val_label = create_label(n_way, q_query).to(device)

                logits = model.functional_forward(query_set, fast_weights)
                loss = criterion(logits, val_label)
                task_loss.append(loss)
                task_acc.append(calculate_accuracy(logits, val_label))
            else:
                """ testing """
                logits = model.functional_forward(query_set, fast_weights)
                labels.extend(torch.argmax(logits, -1).cpu().numpy())

    if return_labels:
        return labels

    batch_loss = torch.stack(task_loss).mean()
    task_acc = np.mean(task_acc)

    if train:
        # Update model
        model.train()
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    return batch_loss, task_acc

Medium — FO-MAML

FOMAML(First-Order MAML)是MAML(Model-Agnostic Meta-Learning)的一种简化版本。MAML是一种元学习算法,旨在通过训练模型使其能够在少量新数据上快速适应新任务。FOMAML通过忽略二阶导数来简化MAML的计算过程,从而提高计算效率。它在许多情况下表现良好,尤其是在计算资源有限的情况下。然而,它也可能在某些任务上表现不如完整的MAML。

MAML的核心思想是通过在多个任务上进行训练,使得模型能够在面对新任务时,只需少量数据就能快速收敛到一个好的参数配置。具体来说,MAML的训练过程包括两个层次的优化:

  • 内层优化(Inner Loop):在每个任务上进行少量的梯度更新,以适应该任务。

  • 外层优化(Outer Loop):在所有任务上进行梯度更新,以优化模型的初始参数,使得模型在面对新任务时能够快速适应。

""" Inner Loop Update """
grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=False) # create_graph=False:这个参数表示在计算梯度时不创建计算图。在FOMAML中,我们只关心一阶导数,因此不需要创建计算图
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), grads)
            )


""" Outer Loop Update """
        # TODO: Finish the outer loop update
        # raise NotimplementedError
        meta_batch_loss.backward()
        optimizer.step()

Strong — MAML

""" Inner Loop Update """
grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), grads)
            )


""" Outer Loop Update """
        # TODO: Finish the outer loop update
        # raise NotimplementedError
        meta_batch_loss.backward()
        optimizer.step()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2161759.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

9/24作业

1. 分文件编译 分什么要分文件编译? 防止主文件过大,不好修改,简化编译流程 1) 分那些文件 头文件:所有需要提前导入的库文件,函数声明 功能函数:所有功能函数的定义 主函数:main函数&…

请不要在TS中使用Function类型

在 TypeScript 中,避免使用 Function 作为类型。Function 代表的是“任意类型的函数”,这会带来类型安全问题。对于绝大多数情况,你可能更希望明确地指定函数的参数和返回值类型。 如果你确实想表达一个可以接收任意数量参数并返回任意类型的…

Kali wireshark抓包

wireshark 查看指定网卡进出流量 构造一个只能显示ICMP数据包的显示过滤器 ARP 同理,显示过滤器会屏蔽所有除了 ARP 请求和回复之外的数据包

力扣 中等 92.反转链表 II

文章目录 题目介绍题解 题目介绍 题解 class Solution {public ListNode reverseBetween(ListNode head, int left, int right) {// 创建一个哑节点,它的 next 指向头节点,方便处理ListNode dummy new ListNode(0, head);// p0 用于指向反转部分的前一个…

3. 轴指令(omron 机器自动化控制器)——>MC_MoveVelocity

机器自动化控制器——第三章 轴指令 6 MC_MoveVelocity变量▶输入变量▶输出变量▶输入输出变量 功能说明▶指令详情▶时序图▶重启运动指令▶多重启动运动指令▶异常 动作示例▶动作示例▶梯形图▶结构文本(ST) MC_MoveVelocity 使用伺服驱动器的位置控制模式,进行…

聊一下cookie,session,token的区别

cookie cookie是存放在客户端的,主要用于会话管理和用户数据保存;cookie通过http报文的请求头部分发送给服务器,服务器根据cookie就可以获取到里面携带的session id(用于获取服务器中对应的session数据),因为http是无状态协议,我们通常就是通过cookie去维护状态的 cookie是在…

Kali 联网

VMware 中分三种网络模式 桥接模式:默认余宿主机 VMnet0 绑定,像一台独立机 NAT 模式:默认余宿主机 VMnet8 绑定,需要通过物理机连接外网 仅主机模式:默认余宿主机 VMnet1 绑定,只能与物理机通信 VMware…

Linux系统容器化部署中,构建Docker 镜像中包含关键指令和参数的文件dockerfile的详细介绍

目录 一、Dockerfile的用处 1、自动化构建 2、可重复性 3、可移植性 4、版本控制 5、优化镜像大小 6、便于分享和分发 二、Dockerfile 的基本结构 1、基础镜像(FROM) 2、维护者信息(MAINTAINER/LABEL maintainer) 3、设置工作目…

C++之STL—List 链表

双向链表 链表的组成:链表由一系列**结点**组成 结点的组成:一个是存储数据元素的**数据域**,另一个是存储下一个结点地址的**指针域** STL中的链表是一个双向循环链表 构造函数 List 赋值和交换 容器大小操作 - 判断是否为空 --- empty - …

深度学习实战:UNet模型的训练与测试详解

🍑个人主页:Jupiter. 🚀 所属专栏:Linux从入门到进阶 欢迎大家点赞收藏评论😊 目录 1、云实例:配置选型与启动1.1 登录注册1.2 配置 SSH 密钥对1.3 创建实例1.4 登录云实例 2、云存储:数据集上传…

JavaScript --json格式字符串和对象的转化

json字符串解析成对象 &#xff1a; var obj JSON.parse(str) 对象转化成字符串&#xff1a;var str1 JSON.stringify(obj1) <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Com…

第五篇:Linux进程的相关知识总结(1)

目录 第四章&#xff1a;进程 4.1进程管理 4.1.1进程管理需要的学习目标 4.1.1.1了解进程的相关信息 4.1.1.2僵尸进程的概念和处理方法&#xff1a; 4.1.1.3PID、PPID的概念以及特性&#xff1a; 4.1.1.4进程状态 4.1.2进程管理PS 4.1.2.1静态查看进程 4.1.2.1.1自定义…

搭建EMQX MQTT服务器并接入Home Assistant和.NET程序

本文主要介绍如何使用Docker搭建EMQX MQTT服务器&#xff0c;并将其接入到Home Assistant中&#xff0c;最后演示如何使用.NET接入MQTT。 1. 背景 在智能家居系统中&#xff0c;MQTT&#xff08;消息队列遥测传输协议&#xff09;是一种轻量级的消息传输协议&#xff0c;特别适…

《深度学习》—— 神经网络中的数据增强

文章目录 一、为什么要进行数据增强&#xff1f;二、常见的数据增强方法1. 几何变换2. 颜色变换3. 尺寸变换4. 填充5. 噪声添加6. 组合变换 三、代码实现四、注意事项五、总结 一、为什么要进行数据增强&#xff1f; 神经网络中的数据增强是一种通过增加训练数据的多样性和数量…

动态规划11,完全背包模板

NC309 完全背包 问题一&#xff1a;求这个背包至多能装多大价值的物品&#xff1f; 状态表示&#xff1a;经验题目要求 dp[i][j] 表示 从前i个物品中挑选&#xff0c;总体积不超过j&#xff0c;所有选法中&#xff0c;能选出来的最大价值。 状态转移方程 根据最后一步的状态&a…

vue2 搜索高亮关键字

界面&#xff1a; 搜索 “成功” 附上代码&#xff08;开箱即用&#xff09; <template><div class"box"><input class"input-box" v-model"searchKeyword" placeholder"输入搜索关键字" /><div class"r…

【深度】边缘计算神器之数据网关

分布式计算、云边协同、互联互通是边缘计算设备的三项重要特征。 边缘计算设备通过分布式计算模式&#xff0c;将数据处理和分析任务从中心化的云平台下放到设备网关&#xff0c;即更接近数据源的地方&#xff0c;从而显著降低了数据传输的延迟&#xff0c;提高了响应速度和处…

OpenCV normalize() 函数详解及用法示例

OpenCV的normalize函数用于对数组&#xff08;图像&#xff09;进行归一化处理&#xff0c;即将数组中的元素缩放到一个指定的范围或具有一个特定的标准&#xff08;如均值和标准差&#xff09;。它有两个原型函数, 如下: Normalize()规范化数组的范数或值范围。当normTypeNORM…

拾色器的取色的演示

前言 今天&#xff0c;有一个新新的程序员问我&#xff0c;如何确定图片中我们需要选定的颜色范围。一开始&#xff0c;我感到对这个问题很不屑。后来&#xff0c;想了想&#xff0c;还是对她说&#xff0c;你可以参考一下“拾色器”。 后来&#xff0c;我想关于拾色器&#…

C++ std::any升级为SafeAny

std::any测试 #include <any>class A { public:int8_t a; };int main(int argc, char* argv[]) {std::any num((int8_t)42);auto a std::any_cast<A>(num);return 0; }异常&#xff1a; 0x00007FFA9385CD29 处(位于 test.exe 中)有未经处理的异常: Microsoft C 异…