原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!由于工作量大,准备整8个系列完事,-----系列5

news2024/7/4 4:30:19

在这里插入图片描述

文章目录

  • 前言
  • 一、原始程序---计算原型,开始训练,计算损失
  • 二、每一行代码的详细解释
    • 2.1 粗略分析
    • 2.2 每一行代码详细分析


前言

承接系列4,此部分属于原型类中的计算原型,开始训练,计算损失函数。


一、原始程序—计算原型,开始训练,计算损失

def compute_center(self,data_set):	#data_set是一个numpy对象,是某一个支持集,计算支持集对应的中心的点
		center = 0
		for i in range(self.Ns):
			data = np.reshape(data_set[i], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])
			data = Variable(torch.from_numpy(data))
			data = self.model(data)[0]	#将查询点嵌入另一个空间
			if i == 0:
				center = data
			else:
				center += data
		center /= self.Ns
		return center
	
	def train(self,labels_data,class_number):	#网络的训练
		#Select class indices for episode
		class_index = list(range(class_number))
		random.shuffle(class_index)
		choss_class_index = class_index[:self.Nc]#选20个类
		sample = {'xc':[],'xq':[]}
		for label in choss_class_index:
			D_set = labels_data[label]
			#从D_set随机取支持集和查询集
			support_set,query_set = self.randomSample(D_set)
			#计算中心点
			self.center[label] = self.compute_center(support_set)
			#将中心和查询集存储在list中
			sample['xc'].append(self.center[label])	#list
			sample['xq'].append(query_set)
		#优化器
		optimizer = torch.optim.Adam(self.model.parameters(),lr=0.001)
		optimizer.zero_grad()
		protonets_loss = self.loss(sample)
		protonets_loss.backward()
		optimizer.step()
	
	def loss(self,sample):	#自定义loss
		loss_1 = autograd.Variable(torch.FloatTensor([0]))
		for i in range(self.Nc):
			query_dataSet = sample['xq'][i]
			for n in range(self.Nq):
				data = np.reshape(query_dataSet[n], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])
				data = Variable(torch.from_numpy(data))
				data = self.model(data)[0]	#将查询点嵌入另一个空间
				#查询点与每个中心点逐个计算欧氏距离
				predict = 0
				for j in range(self.Nc):
					center_j = sample['xc'][j]
					if j == 0:
						predict = eucli_tensor(data,center_j)
					else:
						predict = torch.cat((predict, eucli_tensor(data,center_j)), 0)
				#为loss叠加
				loss_1 += -1*F.log_softmax(predict,dim=0)[i]
		loss_1 /= self.Nq*self.Nc
		return loss_1

二、每一行代码的详细解释

2.1 粗略分析

第一个函数 compute_center(self,data_set) 用于计算支持集中心点的坐标。输入参数 data_set 是一个 numpy 对象,代表支持集。该函数中用了一个 for 循环遍历了每一个支持集中的样本,将其嵌入到另一个空间,并计算其总和来求得所有样本的中心点。最后返回计算出的中心点的坐标。

第二个函数 train(self,labels_data,class_number) 是网络的训练函数。其中 labels_data 是标签数据,class_number 是类别数。首先从 class_number 中随机选取出 Nc 个类,对于每个选出来的类,从其标签数据 D_set 中随机选取出支持集和查询集,并将支持集传入 compute_center() 函数计算中心点。接着将计算出的中心点和查询集存储在样本字典 sample 中。最后使用 Adam 优化器对模型进行优化,并计算损失(调用了 loss 函数),将反向传播得到的梯度更新到模型中。

第三个函数def loss(self,sample)是一个自定义的损失函数,它的作用是计算样本的损失值。在这个损失函数中,使用了欧氏距离和softmax函数。

2.2 每一行代码详细分析

def compute_center(self,data_set): - 这是一个方法,用于计算给定数据集(支持集)的中心点。

2-4. center = 0 - 初始化中心点的变量为0。

5-8. for i in range(self.Ns): - 遍历数据集中的每个数据点。

9-14. 这部分代码将数据集中的每个数据点重塑为适应模型输入的形状,并将其转换为PyTorch的Variable。然后,使用模型将查询点嵌入另一个空间。

if i == 0: - 如果这是第一个数据点,则将查询点设置为中心点。

16-19. 否则,将查询点添加到中心点。

center /= self.Ns - 计算中心点,这是所有数据点的平均值。

return center - 返回计算得到的中心点。

接下来是 train 方法:

23-24. 从给定的标签数据中选择类别索引并随机洗牌。选择特定数量的类别(self.Nc)。

25-30. 对于所选类别中的每一个,从其数据中随机选择支持集和查询集。

31-33. 使用 compute_center 方法计算每个类的中心点,并将其存储在列表中。同时将查询集也存储在列表中。

34-37. 初始化优化器,这里使用Adam优化算法,学习率设置为0.001。然后清空梯度缓存。

38-42. 计算损失函数值,该损失函数是根据自定义的损失函数计算的。然后进行反向传播以计算梯度。

optimizer.step() - 使用优化器更新模型的参数。

最后是自定义的损失函数 loss

45-46. 初始化一个张量 loss_1 为0,它用于累计损失值。

47-52. 对于每个类别(self.Nc),遍历查询集中的每个数据点。对于每个查询点,将其嵌入到另一个空间中,并计算它与每个中心点之间的欧氏距离。

53-57. 将所有的距离组合在一起,并使用softmax函数将其转换为概率值。然后,对于每个查询点,累加其与所有中心点的负对数似然损失值。

loss_1 /= self.Nq*self.Nc - 将损失值除以查询集中的数据点数量和类别数量以获得平均损失值。

return loss_1 - 返回计算得到的损失值。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1224083.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

IO流-序列化流

一,序列化(把java对象写到对象中去) 二, Object OutputStream(对象字节输出流) 三,案例 package BigDecimal;import java.io.FileOutputStream; import java.io.ObjectOutputStream;public class Main {public static…

upload-labs(1-17关攻略详解)

upload-labs pass-1 上传一个php文件,发现不行 但是这回显是个前端显示,直接禁用js然后上传 f12禁用 再次上传,成功 右键打开该图像 即为位置,使用蚁剑连接 连接成功 pass-2 源码 $is_upload false; $msg null; if (isse…

Springboot集成JDBC

1&#xff0c;pom.xml配置jar包 <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-jdbc</artifactId> </dependency> 2&#xff0c;配置数据源信息 server:port: 8088spring:datasource:dr…

Java-整合OSS

文章目录 前言一、OSS 简介二、OSS 的使用1. Bucket 的创建与文件上传2. 创建 RAM 与用户授权3. 图形化管理工具-ossbrowser 三、Java 整合 OSS1. 基本实现2. 客户端直传 前言 最近公司的技术负责人让我整合下 OSS 到项目中&#xff0c;所以花了一点时间研究了下 OSS&#xff…

Docker入门学习笔记

学习笔记网址推送&#xff1a;wDocker 10分钟快速入门_哔哩哔哩_bilibili docker是用来解决什么问题的&#xff1f; 例如当你在本地主机写了个web应用&#xff0c;而你打算将该应用发送给其他客户端进行案例测试和运行&#xff0c;若是传统做法&#xff0c;就比较复杂&#xf…

十个一手app拉新地推拉新推广接单平台,放单/接任务渠道

做过地推拉新的朋友一定都非常清楚&#xff0c;app拉新推广一手接单平台&#xff0c;和非一手接任务平台之间的收益差&#xff0c;可以用天壤之别来形容。那么一手app拉新渠道应该怎么找&#xff1f;下面这十个常见的地推拉新app接单平台&#xff0c;一定要收藏。 1. 聚量推客…

TCP协议相关实验

文章目录 一.TCP相关实验1.理解CLOSE_WAIT状态2.理解TIME_WAIT状态3.解决TIME_WAIT状态引起的bind失败的方法4.理解listen的第二个参数5.使用Wireshark分析TCP通信流程 二.TCP与UDP1.TCP与UDP对比2.用UDP实现可靠传输&#xff08;经典面试题&#xff09; 一.TCP相关实验 1.理解…

C++模版初阶

泛型编程 如下的交换函数中&#xff0c;它们只有类型的不同&#xff0c;应该怎么实现一个通用的交换函数呢&#xff1f; void Swap(int& left, int& right) {int temp left;left right;right temp; }void Swap(double& left, double& right) {double temp…

大模型重塑软件设计,南京真我加入飞桨技术伙伴,大模型生态圈成员又添一员!...

为帮助伙伴更快、更好的应用大模型技术&#xff0c;飞桨技术伙伴体系及权益基于星河共创计划全面升级&#xff0c;通过丰富的场景、技术、算力、品牌等资源&#xff0c;为伙伴企业提供一站式的大模型资源对接&#xff0c;全面降低创建AI原生应用的门槛。 近日&#xff0c;南京真…

数据同步策略解读

前言 我们都知道在大多数情况下&#xff0c;通过浏览器查询到的数据都是缓存数据&#xff0c;如果缓存数据与数据库的数据存在较大差异的话&#xff0c;可能会产生比较严重的后果的。对此&#xff0c;我们应该也必须保证数据库数据、缓存数据的一致性&#xff0c;也就是就是缓…

Swagger(3):Swagger入门案例

1 编写SpringBoot项目 新建一个Rest请求控制器。 package com.example.demo.controller;import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.Reques…

Linux下查看pytorch运行时真正调用的cuda版本

一般情况我们会安装使用多个cuda版本。而且pytorch在安装时也会自动安装一个对应的版本。 正确查看方式&#xff1a; 想要查看 Pytorch 实际使用的运行时的 cuda 目录&#xff0c;可以直接输出 cpp_extension.py 中的 CUDA_HOME 变量。 import torch import torch.utils imp…

​软考-高级-系统架构设计师教程(清华第2版)【第13章 层次式架构设计理论与实践(P466~495)-思维导图】​

软考-高级-系统架构设计师教程&#xff08;清华第2版&#xff09;【第13章 层次式架构设计理论与实践&#xff08;P466~495&#xff09;-思维导图】 课本里章节里所有蓝色字体的思维导图

原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!-----系列6 (承接系列5)

文章目录 一、原始代码---随机采样和评估模型二、详细解释分析每一行代码 一、原始代码—随机采样和评估模型 def randomSample(self,D_set): #从D_set随机取支持集和查询集&#xff08;20个类中的其中一个类&#xff0c;shape为[20,105,105]&#xff09;index_list list(ran…

算法设计与分析 | 分治棋盘

题目 在一个2^k * 2^k个方格组成的棋盘中&#xff0c;恰有一个方格与其他方格不同&#xff0c;称该方格为一特殊方格&#xff0c;且称该棋盘为一特殊棋盘。在棋盘覆盖问题中&#xff0c;要用图示的4种不同形态的L型骨牌覆盖给定的特殊棋盘上除特殊方格以外的所有方格&#xff0…

2023亚太杯数学建模思路 - 案例:异常检测

文章目录 赛题思路一、简介 -- 关于异常检测异常检测监督学习 二、异常检测算法2. 箱线图分析3. 基于距离/密度4. 基于划分思想 建模资料 赛题思路 &#xff08;赛题出来以后第一时间在CSDN分享&#xff09; https://blog.csdn.net/dc_sinor?typeblog 一、简介 – 关于异常…

【Linux】进程间通信 -- 管道

对于进程间通信的理解 首先&#xff0c;进程间通信的本质是&#xff0c;让不同的进程看到同一份资源&#xff08;这份资源不能隶属于任何一个进程&#xff0c;即应该是共享的&#xff09;。而进程间通信的目的是为了实现多进程之间的协同。 但由于进程运行具有独立性&#xff…

stable diffusion十七种controlnet详细使用方法总结

个人网站&#xff1a;https://tianfeng.space 前言 最近不知道发点什么&#xff0c;做个controlnet 使用方法总结好了&#xff0c;如果你们对所有controlnet用法&#xff0c;可能了解但是有点模糊&#xff0c;希望能对你们有用。 一、SD controlnet 我统一下其他参数&#…

python 对图像进行聚类分析

import cv2 import numpy as np from sklearn.cluster import KMeans import time# 中文路径读取 def cv_imread(filePath, cv2_falgcv2.COLOR_BGR2RGB): cv_img cv2.imdecode(np.fromfile(filePath, dtypenp.uint8), cv2_falg) return cv_img# 自定义装饰器计算时间 def…

解决:虚拟机远程连接失败

问题 使用FinalShell远程连接虚拟机的时候连接不上 发现 虚拟机用的VMware&#xff0c;Linux发行版是CentOs 7&#xff0c;发现在虚拟机中使用ping www.baidu.com是成功的&#xff0c;但是使用FinalShell远程连接不上虚拟机&#xff0c;本地网络也ping不通虚拟机&#xff0c…