Pytorch学习笔记1:张量+训练参数传入与处理+制作训练集

news2024/12/24 2:35:58

文章目录

  • Pytorch中张量的一些常见函数
    • 最基础也最常见的方法
    • 关于Indexing, Slicing, Joining, Mutating Ops(索引、切片、聚合、旋转)
    • 随机种子
      • torch.bernoulli(input)
      • torch.normal
      • torch.rand(size)
      • torch.randn(size)
      • torch.randperm(n)
  • Python--argparse--命令行选项、参数和子命令解析器
    • 创建一个解析器
    • 添加参数
    • 解析参数
    • ArgumentParser对象
    • add_argument()方法
    • parse_args()方法
  • Pytorch中Dataset与DataLoader的使用

Pytorch中张量的一些常见函数

最基础也最常见的方法

  • numpy array与Torch Tensor共享内存,更改其中一个的数值,另一个的数值也会更改。

  • 使用torch.函数(size)时,size可以有三种形式:

    1. tuple元组(a,b,)
    2. list列表[a,b,c]
    3. torch.Size
  • array与tensor的区别: tensor可以是n维的;矩阵是二维数组。

  • torch.zeros/ones(size,默认dtype=None或dtype=torch.int32/float32/float64)类型设置可选,默认会根据torch.set_default_tensor_type(type/string)这个API的全局设置,来初始化默认数据类型。

  • torch.numel(input) 计算input张量中的元素数量

  • torch.arange(默认start=0,end,默认step=1)生成一维张量,生成的数值[start, end),一维张量的size为(end-start)/step向上取整,默认数据类型为int32
    在这里插入图片描述

  • torch.range(默认start=0,end,默认step=1)也生成一维张量,其size为(end-start)/step向下取整+1,默认数据类型为float32
    在这里插入图片描述

for i in torch.arange(600):
    print("epoch:", i)
  • torch.eye生成二维张量,对角线为1、其余元素为0,torch.eye(n,m),m不写则为n维方阵
  • torch.full生成特定size的、特定填充值的张量,torch.full(size,full_value)

关于Indexing, Slicing, Joining, Mutating Ops(索引、切片、聚合、旋转)

  • torch.cat(tensors, dim=0/1),tensors是张量的列表,以相同的维度拼接张量。
  • torch.chunk(input,chunk,默认dim=0)将张量分割成chunk个张量,每个张量都是输入张量的特殊的视角;如果输入张量的维度不能整除的话,最后一个张量看起来比较小。
  • torch.gather(input, dim, index)会沿着某一个维,取一些变量,得到输出张量。在输入张量的dim维度上,根据索引取值,其余维度不变。
  • torch.reshape(input, shape)改变输入张量的维度,且其数值以及相对顺序不变,当shape为(-1,)时,将输入张量拉成一维张量。注:(-1)为整数,(-1,)为元组
  • torch.scatter_(dim,index,src)从src张量的所有元素中,根据dim与index索引指明的位置,将src对应位置上的元素写入当前张量中。注:_表示分配的内存不变,原位改变
  • torch.split(tensor, split_size_or_sections, dim=0)这个split比chunk更常用,因为split函数的size参数可以输入一个列表,输入张量被分割成N个输出张量,N=列表长度。
  • torch.squeeze(input, dim=None, out=None)将大小为1的维度抹掉,在很多模型中最后会做一个全连接,全连接层会把某些维度映射为1,再使用squeeze把大小为1的维度抹掉。
  • torch.stack(tensors,默认dim=0)沿着某一个新的维度,来将几个张量拼接起来,注意所有的张量要有同样的大小。与cat不同,stack将张量堆叠起来,维度增加
>>> b
tensor([[0.5181, 0.0923],
        [0.1382, 0.6238],
        [0.2272, 0.2959]])
>>> b.shape
torch.Size([3, 2])
>>> a=torch.rand([3,2])
>>> a
tensor([[0.2284, 0.5338],
        [0.1582, 0.4309],
        [0.8541, 0.3610]])
>>> torch.stack([a,b]).shape
torch.Size([2, 3, 2])
>>> torch.stack([a,b])
tensor([[[0.2284, 0.5338],
         [0.1582, 0.4309],
         [0.8541, 0.3610]],

        [[0.5181, 0.0923],
         [0.1382, 0.6238],
         [0.2272, 0.2959]]])
>>> torch.stack([a,b],dim=1)
tensor([[[0.2284, 0.5338],
         [0.5181, 0.0923]],

        [[0.1582, 0.4309],
         [0.1382, 0.6238]],

        [[0.8541, 0.3610],
         [0.2272, 0.2959]]])
>>> torch.stack([a,b],dim=1).shape
torch.Size([3, 2, 2])
  • torch.take(input,index)将输入张量一律看成一维张量,根据索引index取对应数值,得到输出张量。
  • torch.tile(input,dims)将输入张量在指定维度复制,在dims参数对应的维度上取N则复制N遍。如果dims参数中指定的维度小于输入张量的维度,则在指定维度的前面所有维度补1。因为1即代表着在对应维度上不复制。a_tiled=torch.tile(a,[1,2])表示复制一列a。
  • torch.transpose(input,dim0,dim1)转置,dim0和dim1是两个即将要交换的维度
  • torch.unbind(input,默认dim=0)降维,在指定维度返回所有切片的元组,每个切片都比原来降低一个维度
  • torch.unsqueeze(input,dim)在指定维度增维,增加的维度的大小为1。注:dim为-1代表最后一个维度
  • torch.where(条件语句,a,b) 如果条件成立返回a,条件不成立返回b。a和b为相同形状的张量

随机种子

神经网络参数需要随机初始化,需要生成随机数,利用种子生成随机数,并将种子固定,则相当于从同样的分布中生成随机数, 模型可复现。torch.manual_seed(seed).代码如果调用了numpy,那么numpy的种子也要固定。

torch.bernoulli(input)

在这里插入图片描述
输入张量每个位置的数值为该位置生成1的概率

torch.normal

第一种方法,torch.normal(mean=a,b, std=a,b),注意mean和std的个数要相同,该方法生成很多个数值,每个数值由对应的mean和std来重新生成一次。
也可以,共享均值,不同方差;共享均值和方差,生成多个。

torch.rand(size)

[0,1)均匀分布

torch.randn(size)

均值为0,标准差为1的正态分布中

torch.randperm(n)

对0到n-1之间的数进行随机组合

Python–argparse–命令行选项、参数和子命令解析器

神经网络中常见的一种定义参数的方式:

parser = argparse.ArgumentParser(description='PyTorch Time series forecasting')
parser.add_argument('--data', type=str, required=True,
                    help='location of the data file')
parser.add_argument('--model', type=str, default='LSTNet',
                    help='')
parser.add_argument('--hidCNN', type=int, default=100,
                    help='number of CNN hidden units')
parser.add_argument('--hidRNN', type=int, default=100,
                    help='number of RNN hidden units')
parser.add_argument('--window', type=int, default=24 * 7,
                    help='window size')
parser.add_argument('--CNN_kernel', type=int, default=6,
                    help='the kernel size of the CNN layers')
parser.add_argument('--highway_window', type=int, default=24,
                    help='The window size of the highway component')
parser.add_argument('--clip', type=float, default=10.,
                    help='gradient clipping')
parser.add_argument('--epochs', type=int, default=100,
                    help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                    help='batch size')
parser.add_argument('--dropout', type=float, default=0.2,
                    help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--seed', type=int, default=54321,
                    help='random seed')
parser.add_argument('--gpu', type=int, default=None)
parser.add_argument('--log_interval', type=int, default=2000, metavar='N',
                    help='report interval')
parser.add_argument('--save', type=str,  default='model/model.pt',
                    help='path to save the final model')
parser.add_argument('--cuda', type=str, default=True)
parser.add_argument('--optim', type=str, default='adam')
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--horizon', type=int, default=12)
parser.add_argument('--skip', type=float, default=24)
parser.add_argument('--hidSkip', type=int, default=5)
parser.add_argument('--L1Loss', type=bool, default=True)
parser.add_argument('--normalize', type=int, default=2)
parser.add_argument('--output_fun', type=str, default='sigmoid')

# 识别由用户填写的一些参数
args = parser.parse_args()

其中的argparse模块的用法如下:

argparse模块用来编写用户友好的命令行接口。程序定义它需要的参数,argparse从sys.argv解析出那些参数。argparse模块还在用户给程序传入无效参数时报出错误信息。

当在Linux系统中复现代码时,以命令行方式运行main.py的命令要加参数后缀python main.py --data dataSet,具体的参数后缀见复现代码,代码中的参数定义如果使用argparse模块,则见本章详解

创建一个解析器

即创建一个ArgumentParser对象,ArgumentParser对象包含将命令行解析成Python数据类型所需的全部信息。
如:

import argparse

# 使用argparse的第一步--创建一个ArgumentParser对象
parser = argparse.ArgumentParser(description='Process some integers')

添加参数

给一个ArgumentParser添加程序参数信息是通过调用add_argument()方法这些调用指定ArgumentParser如何获取命令行字符串并将其转换为对象。这些信息在parse_args()调用时被存储和使用。
如:

parse.add_argument('integers', metavar='N', type=int, nargs='+')
parse.add_argument('--sum', dest='accumulate', action='store_const', const=sum, default=max)

然后再调用parse_args()将返回一个具有integersaccumulate两个属性的对象。intergers属性将是一个包含一个或者多个整数的列表,而accumulate属性当命令行中指定了–sum时,将是sum()函数;否则是max()函数。

解析参数

ArgumentParser通过parse_args()方法解析参数。它将检查命令行,把每个参数转换为适当的类型然后调用相应的操作。通常parse_args()会被不带参数调用,而ArgumentParser将自动从sys.argv中确定命令行参数。

ArgumentParser对象

argparse.ArgumentParser(prog=None, usage=None…等等参数)

  • prog:程序名称,默认为os.path.basename(sys.argv[0]),使用prog=''来设定另一个值
  • usage:ArgumentParser根据它包含的参数来构建用法消息,描述程序的用法
  • description:大多数对ArgumentParser构造方法的调用都会使用description=参数,这个参数简要描述这个程序做什么以及怎么做
  • parents:单个解析器能够通过提供parents=[]参数给ArgumentParser使用相同的参数,而不是重复这些参数的定义。parents=[]参数使用ArgumentParser对象的列表,从它们那里收集所有的位置和可选的行为,将这些行为加到正在构建的ArgumentParser对象。大多数父解析器会指定add_help=False,否则报错。在通过parents=[]传递解析器之前必须完全初始化父解析器,如果在子解析器之后改变父解析器,这些改变将不会反映在子解析器上。
parent_parser = argparse.ArgumentParser(add_help=False)
parent_parser.add_argument('--parent', type=int)

foo_parser = argparse.ArgumentParser(parents=[parent_parser])

add_argument()方法

ArgumentParser.add_argument(name or flags…[, action][, nargs][, const][, default][, type][, choices][, required][, help][, metavar][, dest])定义单个的命令行参数应当如何解析。

  • name or flags:一个命令或者一个选项字符串的列表,以’-'为前缀是选项,剩下的参数会被假定为位置参数。选项参数可以这样被创建parser.add_argument('-f', '--foo');位置参数可以这么创建parser.add_argument('bar')
  • type:默认情况下,解析器会将命令行参数当作字符串读入。add_argument()的type关键字允许执行任何必要的类型检查和类型转换。如果type关键字使用了default关键字,则类型转换器仅会在默认值为字符串时被应用。普通内置类型和函数可被用作类型转换器。type关键字仅被应用于异常的简单转换。任何具有更复杂错误处理或资源管理的转换都应当在参数被解析后由下游代码来完成。
  • default:default值在选项未在命令行中出现时使用,即parser.parse_args([])方法[]中没有指定内容
  • required:argparse模块会认为-f和–bar是指明可选的参数,他们总是可以在命令行中被忽略。要让一个选项成为必需的,则可以将True作为required=关键字参数传给add_argument()。如果一个选项被标记为required,则当该选项未在命令行中出现时,parse_args()将会报告一个错误。
  • metavar:用来指定一个替代名称,默认情况ArgumentParser对象使用dest值作为每个对象的“name”。默认情况下,对于位置参数动作,dest值将被直接使用,而对于可选参数动作,dest值将被转为大写形式。一个位置参数dest='bar’的引用形式将为bar。一个带有单独命令行参数的可选参数–foo的引用形式将为FOO。metavar仅改变显示的名称,parse_args()对象的属性仍由dest值确定。可以提供元组给metavar,为每个参数指定不同的显示信息。

parse_args()方法

ArgumentParser.parse_args(args=None,namespace=None)将参数字符串转换为对象,并将其设为命名空间的属性。返回带有成员的命名空间。

Pytorch中Dataset与DataLoader的使用

1.介绍
Dataset对单个训练样本使用,做一些预处理,形成(x,y)数据对,以便后续DataLoader使用;DataLoader将数据组成minibatch形式,它会有多种操作可选,比如在一个周期后打乱数据、将数据固定保存在GPU中等等。

2.自定义构建Dataset来导入数据文件
一个自定义的Dataset类必须继承三个方法__init__,__len__,__getitem__
Pytorch中dataset.py中关于getitem的源码:

def __getitem__(self, index) -> T_co:
    raise NotImplementedError
    #根据参数index获取样本

自定义一个Dataset导入数据文件,其中数据文件是时间序列的表格数据,其每一行是一个样本,每一列是样本的一个特性,行与行之间有时间顺序关系。并最终转化为torch的tensor形式
举个栗子: 读取自己的数据文件FileName.csv并形成训练集

import torch
from torch.utils.data import Dataset,DataLoader,TensorDataset
import pandas as pd
import numpy as np

class TimeSeriesDataset(Dataset):
    
    def __init__(self, filepath="FilePath.FileName.csv"):
        
        df = pd.read_csv(
            filepath, header=0,
            encoding='utf-8',
            names=['feat0', 'feat1', 'feat2', 'feat3', 'feat4', 'feat5', 'feat6', 'feat7', 'feat8', 'feat9'],
            dtype=np.float32,
        )
        
        print(f"the shape of dataframe is {df.shape}")
        
        seq = df.iloc[:, 6].values
        
        self.x = torch.from_numpy(seq)
        
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, index):
        return self.x[index]

main():

ts_dataset = TimeSeriesDataset()
ts_dataloader = DataLoader(ts_dataset, batch_size=5000, shuffle=False)
for idx, batch_x in enumerate(ts_dataloader):
    print(f"batch_id:{idx}, {batch_x.shape}")
    print(batch_x)

在这里插入图片描述

Dataset类本质上从磁盘上读取训练数据,__getitem__函数根据索引返回数据。

以上一个一个取样本是Dataset的一种类型map-style datasets,而另一种iterable-style datasets适合流式计算场景

3.DataLoaders将一个一个的样本组成小批次minibatch
DataLoader中__init__方法的几个常用参数:

参数名含义
dataset实例化
batch_size定义大小
shuffle是否随机采样
collate_fn可选填充等后处理

dataset实例化,batch_size要定义大小,shuffle在每个epoch后随机打乱minibatch样本,collate_fn对已经shuffle好的批次进行一些后处理(如padding填充至一定长度)。并且参数sampler和shuffle是互斥的,sampler自定义采样。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/791051.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue3自定义指令 vue中常用自定义指令

文章目录 vue3自定义指令1.什么是自定义指令&#xff1f;2.注册自定义指令2.1 全局注册2.2 局部注册<script setup>中注册&#xff1a;<script>中使用&#xff1a; 3.钩子函数参数详解4.指令传值5.总结 常用自定义指令案例v-longpressv-debounce vue3自定义指令 除…

智能财务分析-亿发财务报表管理系统,赋能中小企业财务数字化转型

对于许多中小企业来说&#xff0c;企业重要部门往往是财务和业务部门。业务负责创收&#xff0c;财务负责控制成本&#xff0c;降低税收风险。但因管理机制和公司运行制度的原因&#xff0c;中小企业往往面临着业务与财务割裂的问题&#xff0c;财务数据不清晰&#xff0c;无法…

2023年下半年广州/深圳软考信息系统项目管理师报名

信息系统项目管理师是全国计算机技术与软件专业技术资格&#xff08;水平&#xff09;考试&#xff08;简称软考&#xff09;项目之一&#xff0c;是由国家人力资源和社会保障部、工业和信息化部共同组织的国家级考试&#xff0c;既属于国家职业资格考试&#xff0c;又是职称资…

Vue3 axios数据请求封装

Vue3 axios数据请求封装 环境&#xff1a;vue3tsvite 首先在项目目录下安装axios 运行 npm install axios 成功后在package.json文件会显示。 目录&#xff1a; request.ts文件代码&#xff1a; import axios from axiosconst request axios.create({baseURL:https://api.…

装配木牛前雷达的2023款创维汽车EV6被评为“最强主动安全车型”

近日&#xff0c;全新升级的2023款创维EV6改款车型接受了中国汽车技术研究中心&#xff08;以下简称“中汽中心”&#xff09;的安全碰撞实验。据称&#xff0c;该款车型在主动安全测试中得分率高达98.97%&#xff0c;这近满分的成绩再次刷新了国内主动安全汽车排行榜&#xff…

Android 之 Paint API —— Typeface (字型)

本节带来Paint API系列的最后一个API&#xff0c;Typeface(字型)&#xff0c;由字义&#xff0c;我们大概可以猜到&#xff0c;这个 API是用来设置字体以及字体风格的&#xff0c;使用起来也非常的简单&#xff01;下面我们来学习下Typeface的一些相关 的用法&#xff01; 官方…

右击不显示TortoiseGit图标处理方法

第一种 右键--》TortoiseGIt--》setting--》Icon Overlays--》Status cache&#xff0c;按照下图设置&#xff0c;然后重启电脑。 第二种 进入注册信息&#xff0c;按照步骤找到HKEY_LOCAL_MACHINE\SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\ShellIconOverlayIden…

CS144学习笔记(1):Internet and IP

1.网络应用 网络应用可以在世界范围内交换数据&#xff0c;例如你可以通过浏览器读取出版社服务器提供的文章。网络应用的基本模型&#xff1a;两台主机各自在本地运行一个程序&#xff0c;程序通过网络来通信。 最常用的通信模型使用双向可靠字节流&#xff0c;通信的两台主机…

我愿称之为最火爆院校!学科评级A+!就业堪比清北!

一、学校及专业介绍 北京邮电大学&#xff08;Beijing University of Posts and Telecommunications&#xff09;&#xff0c;简称北邮&#xff0c;位于北京市&#xff0c;是中华人民共和国教育部直属、工业和信息化部共建的全国重点大学&#xff0c;位列国家“双一流”建设高校…

队列 的初识

Q: 什么是队列&#xff1f; A: 队列又称消息队列&#xff0c;是一种常用于任务间通信的数据结构&#xff0c;队列可以在任务与任务间、中断和任务间传递信息。 Q: 为什么不使用全局变量&#xff1f; A: 如果使用全局变量&#xff0c;任务1修改了变量 a &am…

IIC外设通信

文章目录 IIC外设简介功能介绍框图简化结构图主机发送流程主机接收流程 IIC外设简介 STM32内部集成了硬件IIC收发电路&#xff0c;可由硬件自动执行时钟生成&#xff0c;起始终止条件生成&#xff0c;应答收发位&#xff0c;数据收发等功能&#xff0c;减轻CPU负担。 功能介绍…

2023年深圳杯数学建模D题基于机理的致伤工具推断

2023年深圳杯数学建模 D题 基于机理的致伤工具推断 原题再现&#xff1a; 致伤工具的推断一直是法医工作中的热点和难点。由于作用位置、作用方式的不同&#xff0c;相同的致伤工具在人体组织上会形成不同的损伤形态&#xff0c;不同的致伤工具也可能形成相同的损伤形态。致伤…

Kubernetes ConfigMap - Secret - 使用ConfigMap来配置 Redis

目录 ConfigMap &#xff1a; 参考文档&#xff1a;k8s -- ConfigMap - 简书 (jianshu.com) K8S ConfigMap使用 - 知乎 (zhihu.com) ConfigMap的作用类型&#xff1a; 可以作为卷的数据来源&#xff1a;使用 ConfigMap 来配置 Redis | Kubernetes 可以基于文件创建 Conf…

【C++】类和对象-封装

1.属性和行为作为整体 2.示例2-设计学生类 3.访问权限 4.class和struct的区别 5.成员属性设置为私有 6.设计案例1-立方体类 在main函数前重新补上isSame函数 在Cube类里面添加issamebyclass&#xff0c;利用成员函数判断两个立方体是否相等 自己写的代码&#xff1a; #in…

开放式耳机怎么选?值得入手的开放式耳机有哪些

与封闭式耳机相比&#xff0c;开放式耳机具有更为自然、真实的音质&#xff0c;能够更好地还原音乐现场的声音环境。以下是几款值得推荐的开放式耳机&#xff0c;都来看看有哪些吧。 Top1、NANK南卡00压开放式耳机 推荐理由&#xff1a;死磕开放式传音技术&#xff0c;音质和…

(十四)InfluxDB仪表盘

以下内容来自 尚硅谷&#xff0c;写这一系列的文章&#xff0c;主要是为了方便后续自己的查看&#xff0c;不用带着个PDF找来找去的&#xff0c;太麻烦&#xff01; 第 14 章 InfluxDB仪表盘 14.1 什么是InfluxDB仪表盘 1、前面已经给大家介绍过InfluxDB的仪表盘功能了。点击…

免费数据恢复方法?这3个不要错过!

朋友们&#xff01;本人是个超级马虎的职场新手&#xff0c;在处理工作的时候总是容易误删重要的报表&#xff01;要知道我光是做一个报表就要花很长时间。大家有什么免费数据恢复的方法给我推荐推荐吗&#xff1f;感谢&#xff01;” 在使用电脑时&#xff0c;我们会在电脑中保…

性能测试工具 Jmeter 引入 jar 包踩过的坑

目录 前言&#xff1a; Jmeter 中调用自己编写 jar 中的类出错 错误日志&#xff1a; 出现以上错误的原因&#xff1a; 解决方法&#xff1a; 前言&#xff1a; JMeter 是一种开源的性能测试工具&#xff0c;可以帮助我们快速地进行网站、应用程序等的性能测试和压力测试…

SSM企业固定资产智能管理系统的设计与实现【纯干货分享,M免费领取源码06298】

摘要 信息化社会内需要与之针对性的信息获取途径&#xff0c;但是途径的扩展基本上为人们所努力的方向&#xff0c;由于站在的角度存在偏差&#xff0c;人们经常能够获得不同类型信息&#xff0c;这也是技术最为难以攻克的课题。针对企业固定资产智能管理系统等问题&#xff0c…

计算机系统结构-多处理机

概念&#xff0c;多处理机指的是&#xff0c;多台含cpu的机器共享一个存储器。 &#xff08;可以通过网络宽带&#xff0c;也可以通过线直连这个存储器。当然他们也可以有自己的私有存储器或者高速缓存&#xff09; 几个cpu公用一个总线&#xff0c;没问题。但是如果十几个cpu…