TrustGeo代码理解(一)main.py

news2024/9/22 19:43:53

代码链接:https://github.com/ICDM-UESTC/TrustGeo

一、导入各种模块和数据库

# -*- coding: utf-8 -*-
import torch.nn

from lib.utils import *
import argparse, os
import numpy as np
import random
from lib.model import *
import copy
from thop import profile
import pandas as pd

整体功能是准备运行一个 PyTorch 深度学习模型的环境,具体的功能实现需要查看 lib.utils、lib.model 中的代码,以及整个文件的后续部分。

1、# -*- coding: utf-8 -*-:指定脚本的字符编码为UTF-8。

2、import torch.nn:导入 PyTorch 的神经网络模块,用于定义和训练神经网络。

3、from lib.utils import *从 lib.utils 模块中导入所有内容,这可能包括一些工具函数或辅助函数,用于该脚本或项目的其他部分。

4、import argparse, os:导入 argparse 模块用于解析命令行参数,os 模块用于与操作系统交互。

5、import numpy as np:导入 NumPy 库,用于进行科学计算,特别是多维数组的处理。

6、import random:导入 random 模块,用于生成伪随机数。

7、from lib.model import *:从 lib.model 模块中导入所有内容,这可能包括定义神经网络模型的类等。

8、import copy:导入 copy 模块,用于复制对象,通常用于创建对象的深拷贝

9、from thop import profile从 thop 模块中导入 profile 函数,该函数用于计算 PyTorch 模型的 FLOPs(浮点运算数)和参数数量。(在代码链接中没有找到)

10、import pandas as pd:导入 Pandas 库,用于数据处理和分析,通常用于处理表格型数据。

二、参数初始化(通过命令行参数)

parser = argparse.ArgumentParser()
# parameters of initializing
parser.add_argument('--seed', type=int, default=2022, help='manual seed')
parser.add_argument('--model_name', type=str, default='TrustGeo')
parser.add_argument('--dataset', type=str, default='New_York', choices=["Shanghai", "New_York", "Los_Angeles"],
                    help='which dataset to use')

这部分代码的目的是通过命令行参数设置一些初始化的参数,例如随机数种子、模型名称和数据集名称。这使得在运行脚本时可以通过命令行参数来指定这些参数的值。

1、parser = argparse.ArgumentParser():创建一个 argparse.ArgumentParser 对象,用于解析命令行参数。

2、# parameters of initializing:注释,表示接下来是初始化参数的部分。

3、parser.add_argument('--seed', type=int, default=2022, help='manual seed'):添加一个命令行参数,名称为 '--seed',表示随机数种子,类型为整数,默认值为 2022,help 参数是在命令行中输入 --help 时显示的帮助信息。

4、parser.add_argument('--model_name', type=str, default='TrustGeo'):添加一个命令行参数,名称为 '--model_name',表示模型的名称,类型为字符串,默认值为 'TrustGeo'。

5、parser.add_argument('--dataset', type=str, default='New_York', choices=["Shanghai", "New_York", "Los_Angeles"], help='which dataset to use'):添加一个命令行参数,名称为 '--dataset',表示数据集的名称,类型为字符串,默认值为 'New_York',choices 参数指定了可选的值为 ["Shanghai", "New_York", "Los_Angeles"],用户只能从这三个值中选择。help 参数是在命令行中输入 --help 时显示的帮助信息。

三、训练过程参数设置

# parameters of training
parser.add_argument('--beta1', type=float, default=0.9)
parser.add_argument('--beta2', type=float, default=0.999)
parser.add_argument('--lambda1', type=float, default=7e-3)
parser.add_argument('--lr', type=float, default=5e-3)
parser.add_argument('--harved_epoch', type=int, default=5) 
parser.add_argument('--early_stop_epoch', type=int, default=50)
parser.add_argument('--saved_epoch', type=int, default=200)  

这部分代码的目的是设置一些训练过程中的超参数,例如优化器的动量参数、学习率、权重参数等。这些参数在训练过程中会影响模型的更新和收敛速度。

1、# parameters of training:注释,表示接下来是训练参数的部分。

2、parser.add_argument('--beta1', type=float, default=0.9):添加一个命令行参数,名称为 '--beta1',表示 Adam 优化器的第一个动量(momentum)参数,类型为浮点数,默认值为 0.9。

3、parser.add_argument('--beta2', type=float, default=0.999):添加一个命令行参数,名称为 '--beta2',表示 Adam 优化器的第二个动量参数,类型为浮点数,默认值为 0.999。

4、parser.add_argument('--lambda1', type=float, default=7e-3):添加一个命令行参数,名称为 '--lambda1',表示某个权重参数,类型为浮点数,默认值为 7e-3。

5、parser.add_argument('--lr', type=float, default=5e-3):添加一个命令行参数,名称为 '--lr',表示学习率,类型为浮点数,默认值为 5e-3。

6、parser.add_argument('--harved_epoch', type=int, default=5):添加一个命令行参数,名称为 '--harved_epoch',表示某个 epoch 的值,类型为整数,默认值为 5。

7、parser.add_argument('--early_stop_epoch', type=int, default=50):添加一个命令行参数,名称为 '--early_stop_epoch',表示早停(early stop)的 epoch 数,类型为整数,默认值为 50。

8、parser.add_argument('--saved_epoch', type=int, default=200):  添加一个命令行参数,名称为 '--saved_epoch',表示保存模型的 epoch 数,类型为整数,默认值为 200。

四、模型参数设置

# parameters of model
parser.add_argument('--dim_in', type=int, default=30, choices=[51, 30], help="51 if Shanghai / 30 else")

opt = parser.parse_args()
print("Learning rate: ", opt.lr)
print("Dataset: ", opt.dataset)

这部分代码的目的是解析命令行参数,并打印出学习率和数据集名称。--dim_in 参数用于指定输入维度,可以选择是 51 或者 30。

1、# parameters of model:注释,表示接下来是训模型参数的部分。

2、parser.add_argument('--dim_in', type=int, default=30, choices=[51, 30], help="51 if Shanghai / 30 else"):添加一个命令行参数,名称为 '--dim_in',表示输入的维度,类型为整数,默认值为 30。choices 参数指定了可选的值为 [51, 30],用户只能从这两个值中选择。help 参数是在命令行中输入 --help 时显示的帮助信息。

3、opt = parser.parse_args():使用 argparse 解析命令行参数,将结果存储在 opt 变量中

4、print("Learning rate: ", opt.lr):打印学习率,即 opt 对象中的 lr 属性

5、print("Dataset: ", opt.dataset):打印数据集名称,即 opt 对象中的 dataset 属性

五、设置随机种子数

if opt.seed:
    print("Random Seed: ", opt.seed)
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)
torch.set_printoptions(threshold=float('inf'))

这一部分的目的是确保在使用随机数的场景中,每次运行程序得到的随机结果是可复现的。通过设置相同的随机数种子,可以使得每次运行得到相同的随机数序列。

1、如果 opt 对象中的 seed 属性存在(不为 0 或 False 等假值),则执行以下操作:

  • 打印随机数种子的信息。
  • 使用 random 模块设置 Python 内建的随机数生成器的种子。
  • 使用 PyTorch 的 torch 模块设置随机数种子。

2、torch.set_printoptions(threshold=float('inf')):设置 PyTorch 的打印选项,将打印的元素数量限制设置为无穷大,即不限制打印的元素数量。这样可以确保在打印张量时,所有元素都会被打印出来,而不会被省略。

六、过滤所有警告信息

warnings.filterwarnings('ignore')

过滤掉所有警告信息,将警告信息忽略。这通常用于在代码中避免显示一些不影响程序执行的警告信息,以保持输出的清晰。在某些情况下,警告信息可能是有用的,但如果明确知道这些警告对程序执行没有影响,可以选择忽略它们。

七、动态选择运行环境

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("device:", device)
cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

这部分代码的目的是根据硬件环境动态选择运行模型的设备,并选择相应的 PyTorch 张量类型。如果有可用的 GPU,就使用 GPU 运行模型和 GPU 张量类型;否则,使用 CPU 运行模型和 CPU 张量类型。

1、device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'):创建一个 PyTorch 设备对象,表示运行模型的设备。如果 CUDA 可用(即有可用的 GPU),则使用 'cuda:0' 表示第一个 GPU,否则使用 'cpu' 表示 CPU。

2、print("device:", device):打印设备的信息,即使用的是 GPU 还是 CPU。

3、cuda = True if torch.cuda.is_available() else False:根据 CUDA 是否可用设置一个布尔值,表示是否使用 GPU。如果 CUDA 可用,则 cuda

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1316117.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

文本聚类——文本相似度(聚类算法基本概念)

一、文本相似度 1. 度量指标: 两个文本对象之间的相似度两个文本集合之间的相似度文本对象与集合之间的相似度 2. 样本间的相似度 基于距离的度量: 欧氏距离 曼哈顿距离 切比雪夫距离 闵可夫斯基距离 马氏距离 杰卡德距离 基于夹角余弦的度量 公式…

【从零开始学习JVM | 第七篇】深入了解 堆回收

前言: Java堆作为内存管理中最核心的一部分,承担着对象实例的存储和管理任务。堆内存的高效使用对于保障程序的性能和稳定性至关重要。因此,深入理解Java堆回收的原理、机制和优化策略,对于Java开发人员具有重要的意义。 本文旨在…

竞赛保研 python图像检索系统设计与实现

0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 python图像检索系统设计与实现 🥇学长这里给一个题目综合评分(每项满分5分) 难度系数:3分工作量:3分创新点:4分 该项目较为新颖&#xff0c…

Python基础04-数据容器

零、文章目录 Python基础04-数据容器 1、了解字符串 &#xff08;1&#xff09;字符串的定义 字符串是 Python 中最常用的数据类型。我们一般使用引号来创建字符串。创建字符串很简单&#xff0c;只要为变量分配一个值即可。<class ‘str’>即为字符串类型。一对引号…

计算机毕业设计 基于SpringBoot的二手物品交易管理系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍&#xff1a;✌从事软件开发10年之余&#xff0c;专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精…

【一文带你掌握Java中方法定义、调用和重载的技巧】

方法的定义和调用 方法的定义 方法&#xff08;method&#xff09;是一段用于实现特定功能的代码块&#xff0c;类似于其他编程语言中的函数&#xff08;function&#xff09;。方法被用来定义类或类的实例的行为特征和功能实现。方法是类和对象行为特征的抽象表示。方法与面向…

『PyTorch』张量和函数之gather()函数

文章目录 PyTorch中的选择函数gather()函数 参考文献 PyTorch中的选择函数 gather()函数 import torch a torch.arange(1, 16).reshape(5, 3) """ result: a [[1, 2, 3],[4, 5, 6],[7, 8, 9],[10, 11, 12],[13, 14, 15]] """# 定义两个index…

圆通速递查询,圆通速递单号查询,一键复制查询好的物流信息

批量查询圆通速递单号的物流信息&#xff0c;并将查询好的物流信息一键复制出来。 所需工具&#xff1a; 一个【快递批量查询高手】软件 圆通速递单号若干 操作步骤&#xff1a; 步骤1&#xff1a;运行【快递批量查询高手】软件&#xff0c;第一次使用的朋友记得先注册&…

002.Java实现两数相加

题意 给你两个 非空 的链表&#xff0c;表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的&#xff0c;并且每个节点只能存储 一位 数字。 请你将两个数相加&#xff0c;并以相同形式返回一个表示两数之和的新链表。 示例 输入&#xff1a;l1[2,4,3],l2[5,6,4] 输出…

ChatGPT4 Excel 高级复杂函数案例实践

案例需求: 需求中需要判断多个条件进行操作。 可以让ChatGPT来实现这样的操作。 Prompt:有一个表格B2单元格为入职日期,C2单元格为员工等级(A,B,C),D2单元格为满意度分数(1,2,3,4,5)请给入职一年以上,员工等级为A级并且满意度在3分以上的人发4000元奖金,给入…

普冉(PUYA)单片机开发笔记(10): I2C通信-配置从机

概述 I2C 常用在某些型号的传感器和 MCU 的连接&#xff0c;速率要求不高&#xff0c;距离很短&#xff0c;使用简便。 I2C的通信基础知识请参见《基础通信协议之 IIC详细讲解 - 知乎》。 PY32F003 可以复用出一个 I2C 接口&#xff08;PA3&#xff1a;SCL&#xff0c;PA2&a…

计算机组成原理-函数调用的汇编表示(call和ret指令 访问栈帧 切换栈帧 传递参数和返回值)

文章目录 call指令和ret指令高级语言的函数调用x86汇编语言的函数调用call ret指令小结其他问题 如何访问栈帧函数调用栈在内存中的位置标记栈帧范围&#xff1a;EBP ESP寄存器访问栈帧数据&#xff1a;push pop指令访问栈帧数据&#xff1a;mov指令小结 如何切换栈帧函数返回时…

深度学习记录--矩阵维数

如何识别矩阵的维数 如下图 矩阵的行列数容易在前向和后向传播过程中弄错&#xff0c;故写这篇文章来提醒易错点 顺便起到日后查表改错的作用 本文仅作本人查询参考(摘自吴恩达深度学习笔记)

Flink系列之:SQL提示

Flink系列之&#xff1a;SQL提示 一、动态表选项二、语法三、例子四、查询提示五、句法六、加入提示七、播送八、随机散列九、随机合并十、嵌套循环十一、LOOKUP十二、进一步说明十三、故障排除十四、连接提示中的冲突案例十五、什么是查询块 SQL 提示可以与 SQL 语句一起使用来…

MyBatis Plus 大数据量查询优化

大数据量操作的场景大致如下&#xff1a; 数据迁移 数据导出 批量处理数据 在实际工作中当指定查询数据过大时&#xff0c;我们一般使用分页查询的方式一页一页的将数据放到内存处理。但有些情况不需要分页的方式查询数据或分很大一页查询数据时&#xff0c;如果一下子将数…

【单元测试】Junit 4--junit4 内置Rule

1.0 Rules ​ Rules允许非常灵活地添加或重新定义一个测试类中每个测试方法的行为。测试人员可以重复使用或扩展下面提供的Rules之一&#xff0c;或编写自己的Rules。 1.1 TestName ​ TestName Rule使当前的测试名称在测试方法中可用。用于在测试执行过程中获取测试方法名称…

MATLAB 计算两片点云间的最小距离(2种方法) (39)

MATLAB 计算两片点云间的最小距离 (39) 一、算法介绍二、算法实现1.常规计算方法2.基于KD树的快速计算一、算法介绍 假设我们现在有两片点云 1 和 2 ,需要计算二者之间的最小距离,这里提供两种计算方法,分别是常规计算和基于KD树近邻搜索的快速计算方法,使用的测试数据如…

Java 分布式框架 —— Dubbo 快速入门

1 分布式系统中的相关概念 1.1 大型互联网项目架构目标 传统项目和互联网项目 传统项目&#xff1a;例如 OA、HR、CRM 等&#xff0c;服务对象为&#xff1a;企业员工 互联网项目&#xff1a;天猫、微信、百度等&#xff0c;服务对象为&#xff1a;全体网民 互联网项目特点…

2 使用postman进行接口测试

上一篇&#xff1a;1 接口测试介绍-CSDN博客 拿到开发提供的接口文档后&#xff0c;结合需求文档开始做接口测试用例设计&#xff0c;下面用最常见也最简单的注册功能介绍整个流程。 说明&#xff1a;以演示接口测试流程为主&#xff0c;不对演示功能做详细的测试&#xff0c;…

JVM学习之JVM概述

JVM的整体结构 Hotspot VM是目前市面上高性能虚拟机代表作之一 它采用解释器与即时编译器并存的架构 在今天&#xff0c;Java程序的运行性能已经达到了可以和C/C程序一较高下的地步 Java代码执行流程 具体图为 JVM架构模型 Java编译器输入的指令流基本上是一种基于 栈的指令…