原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!-----系列4

news2025/1/23 11:24:07

在这里插入图片描述

文章目录

  • 原型网络进行分类的基本流程
  • 一、原始代码---计算欧氏距离,设计原型网络(计算原型+开始训练)
  • 二、每一行代码的详细解释
  • 总结


原型网络进行分类的基本流程

利用原型网络进行分类,基本流程如下:

1.对于每一个样本使用编码的方式fφ (),学习到每一个样本的编码表示(信息抽取)。
2.学习到每一个样本的编码表示之后,对于每一个分类下的所有的样本编码进行求和求取平均的操作,将结果作为分类的原型表示。
3.当一个新的数据样本被输入到网络中的时候,对于这个样本使用fφ(),生成其编码表示。
4.计算新的样本的编码表示和每一个分类的原型表示之间的距离情况,通过最下距离来确定查询样本属于哪一个分类。
5.在计算出所有的分类之间的距离之后,使用softmax的方式将距离转换成概率的形式。

一、原始代码—计算欧氏距离,设计原型网络(计算原型+开始训练)

def eucli_tensor(x,y):	#计算两个tensor的欧氏距离,用于loss的计算
	return -1*torch.sqrt(torch.sum((x-y)*(x-y))).view(1)

class Protonets(object):
	def __init__(self,input_shape,outDim,Ns,Nq,Nc,log_data,step,trainval=False):
		#Ns:支持集数量,Nq:查询集数量,Nc:每次迭代所选类数,log_data:模型和类对应的中心所要储存的位置,step:若trainval==True则读取已训练的第step步的模型和中心,trainval:是否从新开始训练模型
		self.input_shape = input_shape
		self.outDim = outDim
		self.batchSize = 1
		self.Ns = Ns
		self.Nq = Nq
		self.Nc = Nc
		if trainval == False:
			#若训练一个新的模型,初始化CNN和中心点
			self.center = {}
			self.model = CNNnet(input_shape,outDim)
		else:
			#否则加载CNN模型和中心点
			self.center = {}
			self.model = torch.load(log_data+'model_net_'+str(step)+'.pkl')		#'''修改,存储模型的文件名'''
			self.load_center(log_data+'model_center_'+str(step)+'.csv')	#'''修改,存储中心的文件名'''
	
	def compute_center(self,data_set):	#data_set是一个numpy对象,是某一个支持集,计算支持集对应的中心的点
		center = 0
		for i in range(self.Ns):
			data = np.reshape(data_set[i], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])
			data = Variable(torch.from_numpy(data))
			data = self.model(data)[0]	#将查询点嵌入另一个空间
			if i == 0:
				center = data
			else:
				center += data
		center /= self.Ns
		return center
	
	def train(self,labels_data,class_number):	#网络的训练
		#Select class indices for episode
		class_index = list(range(class_number))
		random.shuffle(class_index)
		choss_class_index = class_index[:self.Nc]#选20个类
		sample = {'xc':[],'xq':[]}
		for label in choss_class_index:
			D_set = labels_data[label]
			#从D_set随机取支持集和查询集
			support_set,query_set = self.randomSample(D_set)
			#计算中心点
			self.center[label] = self.compute_center(support_set)
			#将中心和查询集存储在list中
			sample['xc'].append(self.center[label])	#list
			sample['xq'].append(query_set)
		#优化器
		optimizer = torch.optim.Adam(self.model.parameters(),lr=0.001)
		optimizer.zero_grad()
		protonets_loss = self.loss(sample)
		protonets_loss.backward()
		optimizer.step()

二、每一行代码的详细解释

def eucli_tensor(x, y):
    return -1 * torch.sqrt(torch.sum((x - y) * (x - y))).view(1)

这是一个函数,用于计算两个张量(tensor)之间的欧氏距离(Euclidean Distance)。它通过计算两个张量差的平方和的平方根,并乘以-1。最后通过 view(1) 将结果转换成一个形状为 (1,) 的张量。

class Protonets(object):
    def __init__(self, input_shape, outDim, Ns, Nq, Nc, log_data, step, trainval=False):
        self.input_shape = input_shape
        self.outDim = outDim
        self.batchSize = 1
        self.Ns = Ns
        self.Nq = Nq
        self.Nc = Nc
        if trainval == False:
            self.center = {}
            self.model = CNNnet(input_shape, outDim)
        else:
            self.center = {}
            self.model = torch.load(log_data + 'model_net_' + str(step) + '.pkl')
            self.load_center(log_data + 'model_center_' + str(step) + '.csv')

这是一个 Protonets 类的定义,它有一个构造函数 __init__,用于初始化类的属性。其中的参数含义如下:

  • input_shape:输入数据的形状。
  • outDim:输出维度。
  • Ns:支持集(support set)的数量。
  • Nq:查询集(query set)的数量。
  • Nc:每次迭代所选类别数。
  • log_data:模型和中心的存储位置。
  • step:训练的步数。
  • trainval:是否重新开始训练模型。

根据 trainval 的取值,分为两种情况进行初始化:

  1. trainval=False:表示训练一个新的模型。此时,初始化一个空的中心字典 self.center,并创建一个名为 CNNnet 的模型对象 self.model,其输入形状为 input_shape,输出维度为 outDim
  2. trainval=True:表示加载已经训练好的模型和中心。同样,初始化一个空的中心字典 self.center。然后通过 torch.load 加载之前训练保存的模型文件 log_data + 'model_net_' + str(step) + '.pkl',并将其赋给 self.model。接着调用 load_center 方法加载之前训练保存的中心文件 log_data + 'model_center_' + str(step) + '.csv'

总结

这段代码是一个用于实现 Protonets 算法的类。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1223607.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

csapp深入理解计算机系统 bomb lab(1)phase_1

实验目的:进一步了解机器级代码,提高汇编语言、调试器和逆向工程等方面原理与技能的掌握。 实验环境:C、linux 实验获取:进入csapp官网,点击linux/x86-64 binary bomb下载实验压缩包。 实验说明:一共有6…

剑指JUC原理-19.线程安全集合

👏作者简介:大家好,我是爱吃芝士的土豆倪,24届校招生Java选手,很高兴认识大家📕系列专栏:Spring源码、JUC源码🔥如果感觉博主的文章还不错的话,请👍三连支持&…

TensorFlow:GPU的使用

**引言** TensorFlow 是一个由 Google 开发的开源机器学习框架,它提供了丰富的工具和库,支持开发者构建和训练各种深度学习模型。而 GPU 作为一种高性能并行计算设备,能够显著提升训练深度学习模型的速度,从而加快模型迭代和优化…

logistic回归后快速绘制亚组森林图!SCI发表级高清图片分分钟生成!

本周为大家重点介绍一下风暴统计平台的最新板块——亚组森林图! 现在亚组分析好像越来越流行,无论是观察性研究还是RCT研究,亚组分析一般配备森林图。 比如这张图: 还有这个: 森林图不仅是画图的画法,背后还…

[C国演义] 哈希的使用和开闭散列的模拟实现

哈希的使用和开闭散列的模拟实现 1. 使用1.1 unordered_map的接口1.2 unordered_set的接口 2. 哈希底层2.1 概念2.2 解决哈希冲突 3. 实现3.1 开放寻址法3.2 拉链法 1. 使用 1.1 unordered_map的接口 构造 void test1() {// 空的unordered_map对象unordered_map<int, in…

智能配电系统解决方案

智能配电系统解决方案是一种集成了先进技术和智能化功能的配电系统&#xff0c;它能够提高电力系统的效率、可靠性和安全性。力安科技智能配电系统解决方案依托电易云-智慧电力物联网&#xff0c;具体实施的方案如下&#xff1a; 智能化设备和传感器&#xff1a;采用智能化的开…

安全框架springSecurity+Jwt+Vue-1(vue环境搭建、动态路由、动态标签页)

一、安装vue环境&#xff0c;并新建Vue项目 ①&#xff1a;安装node.js 官网(https://nodejs.org/zh-cn/) 2.安装完成之后检查下版本信息&#xff1a; ②&#xff1a;创建vue项目 1.接下来&#xff0c;我们安装vue的环境 # 安装淘宝npm npm install -g cnpm --registryhttps:/…

招聘小程序源码 人才招聘网源码

招聘小程序源码 人才招聘网源码 求职招聘小程序源码系统是一种基于微信小程序的招聘平台&#xff0c;它可以帮助企业和求职者快速、方便地进行招聘和求职操作。 该系统通常包括以下功能模块&#xff1a; 用户注册和登录&#xff1a;用户可以通过微信小程序注册和登录&#…

H5ke11--1登录界面一直保存--用本地localStorage存储

目录 代码详解 localStage优点 :一直保存着 注意事项: storage属性们 代码详解 ke8学校陈老师H5-CSDN博客文章浏览阅读76次。实现H5中新增的三个元素&#xff1a;forEach的使用方法。https://blog.csdn.net/m0_72735063/article/details/134019012即此之后 当然可以分为按…

Linux inotify 文件监控

Linux 内核 2.6.13 以后&#xff0c;引入了 inotify 文件系统监控功能&#xff0c;通过 inotify 可以对敏感目录设置事件监听。这样的功能被也被包装成了一个文件监控神器 inotify-tools。 使用 inotify 进行文件监控的过程&#xff1a; 创建 inotify 实例&#xff0c;获取 i…

【从入门到起飞】JavaSE—IO流(1)字节输入流字符输出流

&#x1f38a;专栏【JavaSE】 &#x1f354;喜欢的诗句&#xff1a;天行健&#xff0c;君子以自强不息。 &#x1f386;音乐分享【如愿】 &#x1f384;欢迎并且感谢大家指出小吉的问题&#x1f970; 文章目录 &#x1f33a;概述&#x1f33a;作用&#x1f33a;分类&#x1f33…

如何去开发一个springboot starter

如何去开发一个springboot starter 我们在平时用 Java 开发的时候&#xff0c;在 pom.xml 文件中引入一个依赖就可以很方便的使用了&#xff0c;但是你们知道这是如何实现的吗。 现在我们就来解决这一个问题&#xff01; 创建 SpringBoot 项目 首先我们要做的就是把你想要给别…

Wireshark TS | 应用传输缓慢问题

问题背景 沿用之前文章的开头说明&#xff0c;应用传输慢是一种比较常见的问题&#xff0c;慢在哪&#xff0c;为什么慢&#xff0c;有时候光从网络数据包分析方面很难回答的一清二楚&#xff0c;毕竟不同的技术方向专业性太强&#xff0c;全栈大佬只能仰望&#xff0c;而我们…

【Spring篇】使用注解进行开发

&#x1f38a;专栏【Spring】 &#x1f354;喜欢的诗句&#xff1a;更喜岷山千里雪 三军过后尽开颜。 &#x1f386;音乐分享【如愿】 &#x1f970;欢迎并且感谢大家指出小吉的问题 文章目录 &#x1f33a;原代码&#xff08;无注解&#xff09;&#x1f384;加上注解⭐两个注…

20231117在ubuntu20.04下使用ZIP命令压缩文件夹

20231117在ubuntu20.04下使用ZIP命令压缩文件夹 2023/11/17 17:01 百度搜索&#xff1a;Ubuntu zip 压缩 https://blog.51cto.com/u_64214/7641253 Ubuntu压缩文件夹zip命令 原创 chenglei1208 2023-09-28 17:21:58博主文章分类&#xff1a;LINUX 小工具 文章标签命令行压缩包U…

打不开github网页解决方法

问题&#xff1a; 1、composer更新包总是失败 2、github打不开&#xff0c;访问不了 解决方法&#xff1a;下载一个Watt Toolkit工具&#xff0c;勾选上&#xff0c;一键加速就可以打开了。 下载步骤&#xff1a; 1、打开网址&#xff1a; Watt Toolkit 2、点击【下载wind…

Python (十一) 迭代器与生成器

迭代器 迭代器是访问集合元素的一种方式&#xff0c;可以记住遍历的位置的对象 迭代器有两个基本的方法&#xff1a;iter() 和 next() 字符串&#xff0c;列表或元组对象都可用于创建迭代器 字符串迭代 str1 Python str_iter iter(str1) print(next(str_iter)) print(next(st…

原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!-----系列2

文章目录 一、原始代码二、每一行代码的详细解释 一、原始代码 labels_trainData ,labels_testData load_data() wide labels_trainData[0][0].shape[0] length labels_trainData[0][0].shape[1] for label in labels_trainData.keys():labels_trainData[label] np.reshap…

FastJsonAPI

maven项目 pom.xml <dependencies><dependency><groupId>com.alibaba</groupId><artifactId>fastjson</artifactId><version>2.0.26</version></dependency><dependency><groupId>junit</groupId>&l…

vmware17 虚拟机拷贝、备份、复制使用

可以在虚拟机运行的情况下进行拷贝 查看新安装的虚拟机位置 跳转到上一级目录 复制虚拟机 复制虚拟机整个目录 删除lck文件&#xff0c;不然开机的时候会报错 用vmware 打开新复制的虚拟机 lck文件全部删除 点击开机 开机成功