原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!-----系列8

news2025/1/13 3:31:04

在这里插入图片描述

文章目录

  • 前言
  • 一、原始代码
  • 二、对每一行代码的解释:
  • 总结


前言

这是该系列原型网络的最后一段代码及其详细解释,感谢各位的阅读!


一、原始代码

if __name__ == '__main__':
    ##载入数据
    labels_trainData, labels_testData = load_data()  # labels_trainData是字典,是key:value形式
    class_number_train = max(list(labels_trainData.keys())) #963
    class_number_test = max(list(labels_testData.keys())) #658

    wide = labels_trainData[0][0].shape[0]  # 105      #二维张量,shape[0]代表行数,shape[1]代表列数
    length = labels_trainData[0][0].shape[1]  # 105

    for label in labels_trainData.keys():
        labels_trainData[label] = np.reshape(labels_trainData[label], [-1, 1, wide, length])

    for label in labels_testData.keys():
        labels_testData[label] = np.reshape(labels_testData[label], [-1, 1, wide, length])

    ##初始化模型
    protonets = Protonets((1, wide, length), 10, 5, 5, 60, './log/', 50)  # '''根据需求修改类的初始化参数,参数含义见protonets_net.py'''

    ##训练prototypical_network
    for n in range(100):  ##随机选取x个类进行一个episode的训练
        protonets.train(labels_trainData, class_number_train)
        if n % 2 == 0 and n != 0:  # 每5次存储一次模型,并测试模型的准确率,训练集的准确率和测试集的准确率被存储在model_step_eval.txt中
            torch.save(protonets.model, './log/model_net_' + str(n) + '.pkl')
            protonets.save_center('./log/model_center_' + str(n) + '.csv')
            test_accury = protonets.evaluation_model(labels_testData, class_number_test)
            print(test_accury)
            str_data = str(n) + ',' + str('       test_accury     ') + str(test_accury) + '\n'
            with open('./log/model_step_eval.txt', "a") as f:
                f.write(str_data)
        print(n)

二、对每一行代码的解释:

  1. if __name__ == '__main__':
    这是一个Python的惯用写法,表示当脚本直接被运行时(而不是被作为模块导入时),才会执行下面的代码块。

  2. labels_trainData, labels_testData = load_data()
    调用 load_data() 函数加载数据,并将返回的标签训练数据和标签测试数据保存到 labels_trainDatalabels_testData 变量中。

  3. class_number_train = max(list(labels_trainData.keys()))
    获取标签训练数据中的最大键(即最大类别数),并将其保存到 class_number_train 变量中。

  4. class_number_test = max(list(labels_testData.keys()))
    获取标签测试数据中的最大键(即最大类别数),并将其保存到 class_number_test 变量中。

  5. wide = labels_trainData[0][0].shape[0]
    获取标签训练数据中第一个样本的宽度,并将其保存到 wide 变量中。

  6. length = labels_trainData[0][0].shape[1]
    获取标签训练数据中第一个样本的长度,并将其保存到 length 变量中。

  7. for label in labels_trainData.keys():
    遍历标签训练数据中的所有键。

  8. labels_trainData[label] = np.reshape(labels_trainData[label], [-1, 1, wide, length])
    对每个标签训练数据进行重塑,将其形状改为 [-1, 1, wide, length],其中 -1 表示自动计算维度大小。

  9. for label in labels_testData.keys():
    遍历标签测试数据中的所有键。

  10. labels_testData[label] = np.reshape(labels_testData[label], [-1, 1, wide, length])
    对每个标签测试数据进行重塑,将其形状改为 [-1, 1, wide, length]

  11. protonets = Protonets((1, wide, length), 10, 5, 5, 60, './log/', 50)
    创建一个 Protonets 类的实例,传入模型的初始化参数。

  12. for n in range(100):
    从0到99的循环中,执行以下代码块。

  13. protonets.train(labels_trainData, class_number_train)
    调用 protonets 实例的 train() 方法进行模型训练,传入标签训练数据和类别数。

  14. if n % 2 == 0 and n != 0:
    如果 n 是偶数且不为0,则执行以下代码块。

  15. torch.save(protonets.model, './log/model_net_' + str(n) + '.pkl')
    保存模型到 './log/model_net_' + str(n) + '.pkl' 的文件路径。

  16. protonets.save_center('./log/model_center_' + str(n) + '.csv')
    调用 protonets 实例的 save_center() 方法,将模型的中心点保存到 './log/model_center_' + str(n) + '.csv'

  17. test_accury = protonets.evaluation_model(labels_testData, class_number_test)
    调用 protonets 实例的 evaluation_model() 方法,对模型进行评估并返回测试准确率,将其保存到 test_accury 变量中。

  18. print(test_accury)
    打印测试准确率。

  19. str_data = str(n) + ',' + str(' test_accury ') + str(test_accury) + '\n'
    构建一个字符串以保存到文件中。

  20. with open('./log/model_step_eval.txt', "a") as f:
    打开一个文件,以追加模式写入。


总结

原型网络(Prototypical Network)是一种用于小样本学习的模型,由Jake Snell等人于2017年提出。它是一种基于元学习(meta-learning)的方法,主要用于解决在具有少量标记样本的情况下进行分类任务的问题。

传统的深度学习模型在处理小样本学习时通常表现不佳,因为它们需要大量的标记样本来进行训练。然而,在现实世界中,我们往往只有少量标记样本可用。原型网络通过引入一个用于表示类别的中心向量(原型)的概念,解决了这个问题。

原型网络的功能和优势如下:

  1. 小样本学习:原型网络适用于具有少量标记样本的分类任务,可以在只有几个样本可用时进行准确的分类。

  2. 元学习能力:原型网络通过学习类别的原型向量,能够在遇到新类别时进行快速学习,从而实现元学习的目标。

  3. 欧氏距离度量:原型网络使用欧氏距离来度量样本与原型之间的相似性,从而进行分类推断。这种度量方式非常直观和可解释,使得模型更易于理解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1225986.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

常见树种(贵州省):001松类

摘要:本专栏树种介绍图片来源于PPBC中国植物图像库(下附网址),本文整理仅做交流学习使用,同时便于查找,如有侵权请联系删除。 图片网址:PPBC中国植物图像库——最大的植物分类图片库 一、华山松…

麻将馆电脑计费系统,棋牌室怎么用电脑控制灯计时,佳易王计时计费系统软件下载

麻将馆电脑计费系统,棋牌室怎么用电脑控制灯计时,佳易王计时计费系统软件下 棋牌室电脑灯控系统,需要安装一个灯控器,软件发出开灯和关灯的指令,相应的灯就打开或关闭。在点击开始计时的时候,开灯&#xff…

黔院长 | 为什么要调经络?原来通经络对人体健康如此重要!

人体的组成较为复杂,在外有皮肤、毛发;在内有经络、五脏;其他还有我们看不到的精气、津液等等,也因此人体会生各种各样的疾病。 为什么说经络畅通对人体健康如此重要?身体内外始终是一个统一的整体,内外之间…

vue3基础学习

##以前怎么玩的? ###MVC Model:Bean View:视图 Controller ##vue的ref reactive ref:必须是简单类型 reactive:必须不能是简单类型 ###创建一个Vue项目 npm init vuelatest ###生命周期 ###setup相关 ####Vue2的一些写法 -- options API ####Vue3的写法 组合式API Vu…

Python —— Mock接口测试

前言 今天跟小伙伴们一起来学习一下如何编写Python脚本进行mock测试。 什么是mock? 测试桩,模拟被测对象的返回,用于测试 通常意义的mock指的就是mock server, 模拟服务端返回的接口数据,用于前端开发,第三方接口联调 为什么…

数据结构与算法-哈夫曼树与图

🌞 “永远积极向上,永远豪情满怀,永远热泪盈眶!” 哈夫曼树与图 🎈1.哈夫曼树🔭1.1树与二叉树的转换🔭1.2森林与二叉树的转换🔭1.3哈夫曼树🔎1.3.1哈夫曼树的概念&#x…

大数据HCIE成神之路之数学(2)——线性代数

线性代数 1.1 线性代数内容介绍1.1.1 线性代数介绍1.1.2 代码实现介绍 1.2 线性代数实现1.2.1 reshape运算1.2.2 转置实现1.2.3 矩阵乘法实现1.2.4 矩阵对应运算1.2.5 逆矩阵实现1.2.6 特征值与特征向量1.2.7 求行列式1.2.8 奇异值分解实现1.2.9 线性方程组求解 1.1 线性代数内…

基于STM32单片机数字电压表自动切换量程及源程序

一、系统方案 1、本设计采用这STM32单片机作为主控器。 2、液晶1602显示。 3、内部ADC采集电压0-12V,自动切换档位。 二、硬件设计 原理图如下: 三、单片机软件设计 1、首先是系统初始化 u8 i; u16 a,b,c,d; u16 adcx; float adc; unsigned char datas…

AIGC实战 - 使用变分自编码器生成面部图像

AIGC实战 - 使用变分自编码器生成面部图像 0. 前言1. 数据集分析2. 训练变分自编码器2.1 变分自编码器架构2.2 变分自编码器分析 3. 生成新的面部图像4. 潜空间算术5. 人脸变换小结系列链接 0. 前言 在自编码器和变分自编码器上,我们都仅使用具有两个维度的潜空间。…

Alien Skin Exposure2024免费版图片颜色滤镜插件

Alien Skin Exposure一款非常专业的图片后期处理软件,内含500多种照片滤镜。是一款图片后期处理功能非常强大的软件。这款软件可以对图片的后期效果做很好的处理。 打开Alien Skin Exposure软件,会显示下面这个界面,如图1. ExposureX8win-安…

爱奇艺大数据离在线混部

混部作为一种提高资源利用率、降低成本的的方案,被业界普遍认可。爱奇艺在云原生化与降本增效的过程中,成功将大数据离线计算、音视频内容处理等工作负载与在线业务进行了混部,并且取得了阶段性收益。本文重点以大数据为例,介绍从…

图解API设计风格,看一眼就忘不了了!

点击下方“JavaEdge”,选择“设为星标” 第一时间关注技术干货! 免责声明~ 任何文章不要过度深思! 万事万物都经不起审视,因为世上没有同样的成长环境,也没有同样的认知水平,更「没有适用于所有人的解决方案…

反弹Shell

概述 反弹shell(reverse shell)就是控制端监听在某TCP/UDP端口,被控端发起请求到该端口,并将其命令行的输入输出转到控制端。reverse shell与telnet,ssh等标准shell对应,本质上是网络概念的客户端与服务端…

原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!-----系列7(承接系列6)

文章目录 前言一、原始代码---保存原型点,加载原型点二、代码逐行解释 前言 此部分为原型网络的两个函数,分别为保存原型点函数和加载原型点函数,与之前的系列相承接。 一、原始代码—保存原型点,加载原型点 def save_center(self,path):datas []for …

C/C++ 获取主机网卡MAC地址

MAC地址(Media Access Control address),又称为物理地址或硬件地址,是网络适配器(网卡)在制造时被分配的全球唯一的48位地址。这个地址是数据链路层(OSI模型的第二层)的一部分&#…

端口号大揭秘:网络世界的“门牌号”有多牛?

大家好,今天我们来聊一聊网络中的端口号。如果你以为端口号只是冷冰冰的数字,那你就大错特错了。端口号,这些看似枯燥的数字背后,隐藏着一个个生动的故事。 目录 大家好,今天我们来聊一聊网络中的端口号。如果你以为端…

odoo17前端js框架的演化

odoo17发布了,从界面上看,变化还是很明显的,比16更漂亮了,本来以为源码不会发生太大的变化,结果仔细一瞧,变化也不小。 1、打包好的文件数量和大小发生了变化 打包好的文件从两个变成了一个,在…

适用于全部安卓手机的 5 大免费 Android 数据恢复

您是否面临这样一种情况,即在Android设备上丢失了一些重要文件,但不知道应该选择哪种Android数据恢复来取回它们? 如果您以前从未备份过Android数据,则很难解决问题。 本文将介绍排名前5位的免费Android数据恢复软件。 您可以获…

同花顺,通达信,东方财富股票竞价,早盘板块、概念、题材竞价数据接口

早盘板块、概念、题材竞价数据接口 量化接口地址:https://stockapi.com.cn 通过分析每天早盘的板块竞价,从而判断出今日主力资金的看好方向 地址: https://stockapi.com.cn/v1/base/bkjjzq?tradeDate2023-11-08再结合个股竞价数据筛选出自…

科大讯飞会议笔记本、GoodNotes、E人E本 功能及体验对比

科大讯飞会议笔记本、GoodNotes、E人E本功能及体验对比 【旧文档,怕失传】 通过对科大讯飞会议笔记本、基于iPad的GoodNotes以及E人E本的各项功能指标进行了实际对比,得出了以下结果: 在实际体验中,科大讯飞笔记本在录音方面表…