RIPGeo代码理解(五)utils.py( 辅助函数)第一部分

news2024/9/20 8:10:56

 

 代码链接:RIPGeo代码实现

├── lib # 包含模型(model)实现文件
    │        |── layers.py # 注意力机制的代码。
    │        |── model.py # TrustGeo的核心源代码。
    │        |── sublayers.py # layer.py的支持文件。
    │        |── utils.py # 辅助函数。

一、导入常用库和模块

from __future__ import print_function
import numpy as np
import torch
import warnings
import torch.nn as nn
import random
import matplotlib.pyplot as plt
import copy

这段代码首先包含一些导入语句,接着进行一些版本和警告的处理,最后导入了一些常用的库(numpytorchmatplotlib),并定义了一些常用的模块(nnplt)。

1、from __future__ import print_function:这是为了确保代码同时在Python 2和Python 3中都能正常运行。在Python 2中,print是一个语句,而在Python 3中,print()是一个函数。通过这个导入语句,可以在Python 2中使用Python 3风格的print函数。

2、import numpy as np:导入NumPy库,并用np作为别名。NumPy是一个用于科学计算的库,提供了数组等高性能数学运算工具。

3、import torch::导入PyTorch库。PyTorch是一个深度学习框架,提供了张量计算和神经网络搭建等功能。

4、import warnings:导入warnings模块,用于处理警告。

5、import torch.nn as nn:导入PyTorch中的神经网络模块。

6、import random:导入Python的random模块,用于生成伪随机数。

7、import matplotlib.pyplot as plt:导入matplotlib库的pyplot模块,用于绘制图表。

8、import copy:导入Python的copy模块,用于复制对象。

二、warnings.filterwarnings(action='once')

warnings.filterwarnings(action='once')

设置了在使用warnings.filterwarnings时的参数。filterwarnings函数用于配置警告过滤器,以控制哪些警告会被触发,以及如何处理这些警告。

具体来说,action='once'表示警告信息只会被显示一次。这对于一些可能会频繁触发的警告而言是一种控制方式,以避免在控制台或日志中大量重复的警告信息。在第一次触发警告时,它会被显示,但在后续的同类警告中,将不再显示。

请注意,这个配置仅适用于在warnings模块中配置的警告,它并不会影响其他类型的警告或错误。

三、DataPerturb()  数据扰动

class DataPerturb:
    def __init__(self, eta=1):
        self.eta = eta
        self.loss = torch.nn.MSELoss(reduction='sum')

    def perturb(self, model, data):
        # original
        lm_X, lm_Y, tg_X, tg_Y, lm_delay, tg_delay = data

        # obtain new graph representation
        _, ori_graph_feature = model(lm_X, lm_Y, tg_X,
                                     tg_Y, lm_delay,
                                     tg_delay)

        # add Gaussian data perturb
        new_lm_X, new_lm_Y, new_tg_X, new_tg_Y, new_lm_delay, new_tg_delay = lm_X.clone(), lm_Y.clone(), \
                                                                             tg_X.clone(), tg_Y.clone(), \
                                                                             lm_delay.clone(), tg_delay.clone()
        new_lm_X[:, -16:] += self.eta * torch.normal(0, torch.ones_like(new_lm_X[:, -16:]) * new_lm_X[:, -16:]).cuda()
        new_tg_X[:, -16:] += self.eta * torch.normal(0, torch.ones_like(new_tg_X[:, -16:]) * new_tg_X[:, -16:]).cuda()
        new_lm_delay += self.eta * torch.normal(0, torch.ones_like(new_lm_delay) * new_lm_delay).cuda()
        new_tg_delay += self.eta * torch.normal(0, torch.ones_like(new_tg_delay) * new_tg_delay).cuda()

        # obtain new graph representation
        _, new_graph_feature = model(new_lm_X, new_lm_Y, new_tg_X,
                                     new_tg_Y, new_lm_delay,
                                     new_tg_delay)

        data_loss = self.loss(ori_graph_feature, new_graph_feature)
        return data_loss

这段代码定义了一个名为 DataPerturb 的类,其目的是对给定的数据进行扰动,并计算扰动后的损失。

(一)__init__()

    def __init__(self, eta=1):
        self.eta = eta
        self.loss = torch.nn.MSELoss(reduction='sum')

__init__ 方法中,类初始化时可以指定一个参数 eta,默认为1。该参数用于控制扰动的强度。

损失函数使用MSELoss。

(二)perturb()

    def perturb(self, model, data):
        # original
        lm_X, lm_Y, tg_X, tg_Y, lm_delay, tg_delay = data

        # obtain new graph representation
        _, ori_graph_feature = model(lm_X, lm_Y, tg_X,
                                     tg_Y, lm_delay,
           

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1537515.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【MySQL】2.MySQL数据库的基本操作

目录 数据库基本操作 查看数据库信息 查看数据库结构 显示数据表的结构(字段) 常用的数据类型 数据库管理操作 SQL语句概述 SQL分类 1.DDL:数据定义语言 1.1创建数据库和表 创建数据库 创建数据表 1.2删除数据库和表 删除数据表…

2024年【化工自动化控制仪表】考试试卷及化工自动化控制仪表模拟考试题

题库来源:安全生产模拟考试一点通公众号小程序 化工自动化控制仪表考试试卷是安全生产模拟考试一点通总题库中生成的一套化工自动化控制仪表模拟考试题,安全生产模拟考试一点通上化工自动化控制仪表作业手机同步练习。2024年【化工自动化控制仪表】考试…

langchain+chatglm3+BGE+Faiss Linux环境安装依赖

前言 本篇默认读者已经看过之前windows版本,代码就不赘述,本次讲述是linux环境配置 超短代码实现!!基于langchainchatglm3BGEFaiss创建拥有自己知识库的大语言模型(准智能体)本人python版本3.11.0(windows环境篇&…

基于Gabor滤波器的指纹图像识别,Matlab实现

博主简介: 专注、专一于Matlab图像处理学习、交流,matlab图像代码代做/项目合作可以联系(QQ:3249726188) 个人主页:Matlab_ImagePro-CSDN博客 原则:代码均由本人编写完成,非中介,提供…

索尼下一代游戏主机PS5将于11月20日发售

索尼下一代游戏机PS5将于2020年11月20日发布。据悉,这款游戏机的售价可能会达到499美元(约合人民币3500元)。 我们知道游戏主机的价格低于游戏PC的价格。 既然PS5的主要硬件配置已经公开,那么现在配置一台同样配置的游戏PC需要多少…

从零开始学Spring Boot系列-集成Kafka

Kafka简介 Apache Kafka是一个开源的分布式流处理平台,由LinkedIn公司开发和维护,后来捐赠给了Apache软件基金会。Kafka主要用于构建实时数据管道和流应用。它类似于一个分布式、高吞吐量的发布-订阅消息系统,可以处理消费者网站的所有动作流…

全流程ArcGIS Pro技术应用

GIS是利用电子计算机及其外部设备,采集、存储、分析和描述整个或部分地球表面与空间信息系统。简单地讲,它是在一定的地域内,将地理空间信息和 一些与该地域地理信息相关的属性信息结合起来,达到对地理和属性信息的综合管理。GIS的…

探索AI+电商领域应用与发展

AI火的已经一塌糊涂了,已经有很大一部分的企业和个人已经坐上了这趟超音速列车,但对于电商领域具体都有哪些助理,目前为止还是比较散,今天来顺一下AIGC之与电商到底带来了些什么? 一、什么是AIGC AIGC是内容生产方式…

【LeetCode-74.搜索二维矩阵】

题目详情: 给你一个满足下述两条属性的 m x n 整数矩阵: 每行中的整数从左到右按非严格递增顺序排列。每行的第一个整数大于前一行的最后一个整数。 给你一个整数 target ,如果 target 在矩阵中,返回 true ;否则&am…

微服务day05(中) -- ES索引库操作

索引库就类似数据库表,mapping映射就类似表的结构。 我们要向es中存储数据,必须先创建“库”和“表”。 2.1.mapping映射属性 mapping是对索引库中文档的约束,常见的mapping属性包括: type:字段数据类型,…

[Linux]多线程(在Linux中的轻量级进程(LWP),怎么使用线程(接口))

目录 一、在Linux中的轻量级进程(LWP) 二、多线程的接口 1.创建线程(pthread_create) 2.线程ID(pthread_self) 3.线程终止 终止某个线程而不终止整个进程的三种方法: return pthread_…

高架学习笔记之系统分析与设计

目录 一、结构化方法(SASD) 1.1. 结构化分析方法(SA) 1.1.1. 数据流图(DFD) 1.1.2. 实体联系图(E-R图) 1.1.3. 状态转换图(STD) 1.1.4. 数据字典 1.2. 结构化设计方法&#x…

Python Flask框架 -- 加载静态文件

在项目中,一般都会把静态文件放在 static 目录下,如 images、css、js 等,html 放在 templates 目录下。 .py: from flask import Flask, render_templateapp Flask(__name__)app.route(/static) def static_demo():return rend…

初识C++(一)

目录 一、什么是C 二、关键字: 三、命名空间 : 1. C语言存在的问题: 2. namespace关键字: 3. 注意点: 4.使用命名空间分为三种: 四、输入输出: 五、缺省函数: 1. 什么是缺省…

2024年【山东省安全员C证】考试试卷及山东省安全员C证复审模拟考试

题库来源:安全生产模拟考试一点通公众号小程序 山东省安全员C证考试试卷是安全生产模拟考试一点通生成的,山东省安全员C证证模拟考试题库是根据山东省安全员C证最新版教材汇编出山东省安全员C证仿真模拟考试。2024年【山东省安全员C证】考试试卷及山东省…

《妈妈是什么》笔记(二) 让孩子自己做选择

经典摘录 孩子也会需要独立的空间做事情,求独立、求空间、求私隐 对于不管因为什么,别人在受到肯定和赞赏的时候,会对我们自己的心理带来因“比较”而产生的不适感甚至嫉妒感,进而在行为上影响了我们自己的节奏,产生一…

STL 容器元素减少但内存没有下降且不会自动释放,如何在运行时释放多余内存?【C++】

STL 容器元素减少但内存没有下降且不会自动释放,如何释放多余内存? 前言利用 swap 和匿名对象的性质进行收缩内存 前言 C程序里面我们经常会用到STL容器,容器在运行过程中可能会增长,导致它们分配的内存比实际存储的元素所需的内…

Linux源码包安装

目录 一、transmission源码包安装 二、 nginx源码包安装 一、transmission源码包安装 1、下载编译环境所需的软件包依赖 2、下载transmision源码包到用户主目录下 https://github.com/transmission/transmission/releases/download/4.0.5/transmission-4.0.5.tar.xz 3、解压…

python第三次项目作业

打印课堂上图案 判断一个数是否是质数(素数) 设计一个程序,完成(英雄)商品的购买(界面就是第一天打印的界面) 展示商品信息(折扣)->输入商品价格->输入购买数量->提示付款 输入付款金额->打印购买小票&a…

java 调用window操作系统文本转语音并生成播放文件

一、完整资源直接看这里&#xff1a; java调用window操作系统文本转语音并生成播放文件资源-CSDN文库 二、所需材料 材料一&#xff1a;最关键的&#xff0c;需要引用jacob包&#xff1a; <?xml version"1.0" encoding"UTF-8"?> <project x…