1_1torch基础知识

news2024/12/25 13:08:47

1、torch安装

pytorch cuda版本下载地址:https://download.pytorch.org/whl/torch_stable.html

 其中先看官网安装torch需要的cuda版本,之后安装cuda版本,之后采用pip 下载对应的torch的gpu版本whl来进行安装。使用pip安装时如果是conda需要切换到对应的env下。

2、tensor创建

(1)创建不同类型的tensor

#torch 基础知识
import torch
#创建不同类型的tensor
print("创建不同类型的tensor")
a_float = torch.Tensor([1,2,3])
print("a_float=", a_float)
a_float2 = torch.FloatTensor([1,2,3]).zero_()#增加了zero_表示原地计算绝对值,返回原值,zero则计算一个新的tensor结果
print("a_float2=", a_float2)
a_int = torch.IntTensor([1,2,3])
print("a_int =", a_int)
a_double = torch.DoubleTensor(1,2)
print("a_double =", a_double)
print("后面还有ByteTensor(unsigned 8 bit integer),CharTensor(signed 8 bit integer), ShortTensor(16 bit integer), LongTensor(64 bit integer)")

=============================结果=======================================

创建不同类型的tensor
a_float= tensor([1., 2., 3.])
a_float2= tensor([0., 0., 0.])
a_int = tensor([1, 2, 3], dtype=torch.int32)
a_double = tensor([[4.4743e-316, 4.4757e-316]], dtype=torch.float64)
后面还有ByteTensor(unsigned 8 bit integer),CharTensor(signed 8 bit integer), ShortTensor(16 bit integer), LongTensor(64 bit integer)

(2)通过不同形式输入创建tensor

#通过不同输入创建tensor
print("通过不同输入创建tensor,size, *size, sequence, ndarray,tensor,storage")
a_size = torch.IntTensor(2,3)
print("a_size =", a_size)
a_size2 = torch.Tensor(*[1,2,3])
print("a_size2=",a_size2)
a_sequence = torch.Tensor([1,2,3])
print("a_sequence =", a_sequence)
======================================结果========================================
通过不同输入创建tensor,size, *size, sequence, ndarray,tensor,storage
a_size = tensor([[408112416,     32605,  90579216],
        [        0,        32,         0]], dtype=torch.int32)
a_size2= tensor([[[1.2311e-35, 0.0000e+00, 1.1751e-35],
         [0.0000e+00, 8.9683e-44, 0.0000e+00]]])
a_sequence = tensor([1., 2., 3.])

(3)torch.Tensor()与torch.tensor()的区别

        torch.Tensor()是一个类,是默认张量类型torch.FloatTensor()的别名,用于生成一个单精度浮点类型的张量
        torch.tensor()这里是小写,仅仅是一个python函数,函数原型是torch.tensor(data, dtype=None, device=None, require_grad=False),其中data可以是list、tuple,numpy,ndarray等其他类型,torch.tensor会从data中数据部分拷贝而不是直接引用,根据数据类型生成相应类型的torch.Tensor

a_Tensor = torch.Tensor([1,2,3])
print("a_Tensor =", a_Tensor)
a_tensor = torch.tensor([1,2,3])
print("a_tensor =", a_tensor)

==================================结果====================================

Torch.Tensor()与torch.tensor()的区别
a_Tensor = tensor([1., 2., 3.])
a_tensor = tensor([1, 2, 3])

(4)Tensor类型间的转换

CPU和GPU的Tensor之间转换
data.cuda():cpu –> gpu
data.cpu():gpu –> cpu
Tensor与Numpy Array之间的转换
data.numpy():Tensor –> Numpy.ndarray
torch.from_numpy(data):Numpy.ndarray –> Tensor
Tensor的基本类型转换
tensor.long():
tensor.half():将tensor投射为半精度浮点(16位浮点)类型
tensor.int():
tensor.double():
tensor.float():
tensor.char():
tensor.byte():
tensor.short():
Tensor的基本数据类型转换
type(dtype=None, non_blocking=False, **kwargs):指定类型改变。例如data = data.type(torch.float32)
type_as(tensor):按照给定的tensor的类型转换类型。

#Tensor类型间的转换
print("Tensor类型间的转换")
#1 转换tensor的类型
a_int = torch.IntTensor([1,2,3])
print("a_int =", a_int)
a_float = a_int.type(torch.float)
print("a_float =", a_float)

b_int = torch.IntTensor([6,6,6])
print("b_int =", b_int)
b_float = b_int.type_as(a_float)
print("b_float =", b_float)

b_float2 = b_int.float()
print("b_float2 =", b_float2)

==========================结果=================================

Tensor类型间的转换
a_int = tensor([1, 2, 3], dtype=torch.int32)
a_float = tensor([1., 2., 3.])
b_int = tensor([6, 6, 6], dtype=torch.int32)
b_float = tensor([6., 6., 6.])
b_float2 = tensor([6., 6., 6.])

2.torch.nn.functional中有很多功能。什么时候用nn.Module,什么时候用nn.functional。一般参数情况下有学习参数使用nn.Module,其他情况用nn.functional相对更简单一些。

3、一般模型在训练时会使用model.train,这样会正常使用Batch Normalization和Dropout,

     测试时一般选择model.eval(),这样就不使用Batch Normalization和Dropout

4、对于tensor,维度0表示纵轴,维度1表示横轴。

5、nn.Sequential是表示按照序列进行层运算。

6、加载预训练模型,torchvision中有很多经典网络架构,调用起来十分方便,并且可用人家训练好的权重参数来继续训练,也就是所谓的迁移学习。

需要注意的是别人训练好的任务跟咱们得可不是完全一样,需要把最后head层改一改,一般也就是最后的全连接层,改成咱们自己的任务;

训练时可以全部重头训练,也可以只训练最后任务层;

网络保存可以有选择性,选择验证集中效果最好的。   

 7、from torchvision    import transforms, models, datasets

8、transforms.Compose([

transforms.Resize([96,96]),

transforms.RandomRotation(45),

transforms.RandomRotation(45),#随机旋转,-45到45

transforms.CenterCrop(64),

transforms.HorizontalFilp,RandomVerticalFlip,ColorJitter

ToTensor, Normalize

]),

归一化:x减u除以标准差

9、model name

feature_extract = True都用人家的特征,先不更新。

model_ft = models.resnet18()

最后AdaptiveAvgPool2d(output_size=(1,1))

in_features=512, out_features=1000,bias=True

def set_parameter_requires_grad(model, feature_extracting)

        if feature_extracting:

                for param in model.parameters()#name, param in model.named_parameters()

                        param.requires_grad = False

model_ft = model.resnet18(pretrained=use_pretrained)

num_f = model_ft.fc.in_features

model_ft.fc = nn.Linear(num_f, 102)

model_ft, input_size=initialize_model(model_name,102,feature_extract, use_pretrained=True)

#保存模型就是保存graph、parameter

filename='model.ft'
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t",name)
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)#将需要更新的参数传进来,这里只更新最后的fc层
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1
#学习率每迭代10个epoch衰减原来的1/10
criterion = nn.CrossEntropyLoss()

def train_model(model, dataloaders, criterion, optimizer, num_epoch=25,                      
    filename="best.pt")
    best_acc = 0 #模型保存最好的
    device = 'cuda:0'
    LRs = [optimizer.param_groups[0]['lr']]
    best_model_wts = copy.deepcopy(model.state_dict())
    for inputs, labels in dataloader[phase]:
        optimizer.zero_grad()
        #只有训练的时候计算和更新梯度
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        _, preds = torch.max(outputs, 1)
        if phase == 'train':
            loss.backward()
            optimizer.step()
        runing_loss += loss.item()
        runing_crrets += torch.sum(preds == labels.data)
    epoch_loss = running_loss/len(dataloader[phase].dataset)
       if phase == 'valid' and epoch_acc>best_acc:
            state = {
                'state_dict' :model.state_dict(),
                'best_acc': best_acc,
                'optimizer':optimizer.state_dict(),
            }
        LRs.append(optimizer.param_groups[0]['lr'])
        scheduler.step()
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history


    

10、训练对比

resnet18,只冻住FC层,则性能36%

resnet18,全部训练,则性能

for param in model_ft.parameters():

        param.requires_grad = True

checkpoint = torch.load(filename)

model_ft.load_state_dict(checkpoint['state_dict'])

测试数据预处理:

测试数据处理方法要跟训练时一致才可以

crop操作目的保证输入大小一致

标准化也是必须得,使用训练相同的mean和std

最后颜色通道是一个维度,很多工具包都不一样,需要转换

PIL工具包,from PIL import image

fig =  plt.figure(figsize=(20,20))

11、数据集制作

 

 

 

 

 

 

 

 

(6) 将写好的Dataset进行实例化,并实例化dataloader

 

 

 

                            

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/515881.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Linux】安装部署elasticsearch

安装 Java 在安装 Elasticsearch 之前,您需安装并配置好 JDK, 设置好环境变量 $JAVA_HOME。 众所周知,Elasticsearch 版本很多,不同的版本对 Java 的依赖也有所差别: Elasticsearch 5 需要 Java 8 以上版本;Elasticsearch 6.5 开…

旋转目标检测【1】如何设计深度学习模型

前言 平常的目标检测是平行的矩形框,“方方正正”的;但对于一些特殊场景(遥感),需要倾斜的框,才能更好贴近物体,旋转目标检测来啦~ 一、如何定义旋转框 常见的水平框参数表达方式为&#xff0…

PMP项目管理-[第九章]资源管理

资源管理知识体系: 规划资源管理: 估算活动资源: 获取资源: 建设团队: 管理团队: 9.1 规划资源管理 定义:定义如何估算、获取、管理和利用团队以及实物资源的过程 作用:根据项目类型…

Azure Data Lake Storage Gen2 简介

Azure Data Lake Storage Gen2 基于 Azure Blob 存储构建,是一套用于大数据分析的功能。 Azure Data Lake Storage Gen1 和 Azure Blob Storage 的功能在 Data Lake Storage Gen2 中组合在一起。例如,Data Lake Storage Gen2 提供规模、文件级安全性和文…

Cesium入门之三:隐藏Cesium初始化页面小部件的两种方法

上一级我们实现了第一个三维地球,但是在这个页面上有很多小部件,我们不想让其显示,应该如何设置呢?这一节我们通过两种方式来隐藏小部件 方法1:通过js代码实现 在js代码中,通过在new Cesium.Viewer(conta…

算法之路--直接插入排序算法

在介绍插入排序算法之前,先举证一个我们都熟悉不过的例子即可理解插入排序。我们在打牌的时候,由于每次抽到的牌大小不一,为了在出牌时了解自己手里都还剩什么牌型,所以每次对抽到的新牌都会做一个排序,怎么比较呢&…

AWS 中的另外一种远程工具 AWS Session Manager

作者:SRE运维博客 博客地址:https://www.cnsre.cn/ 文章地址:https://www.cnsre.cn/posts/230129126154/ 相关话题:https://www.cnsre.cn/tags/aws/ 背景需求 因为项目的安全性。为了避免项目的服务器暴露在公网中。很多时候我们…

设计原则之【迪米特法则】,非礼勿近

文章目录 一、什么是迪米特法则1、理解迪米特法则2、如何理解“高内聚、松耦合”? 二、实例1、实例12、实例2 一、什么是迪米特法则 迪米特原则(Law of Demeter LoD)是指一个对象应该对其他对象保持最少的了解,又叫最少知道原则&…

支付系统设计三:渠道网关设计01-总览

文章目录 前言一、开发框架二、E-R图三、管理后台配置四、运行时执行流程五、屏蔽渠道差异总结 前言 在《支付系统设计一:支付系统产品化》文章中,我们知道支付渠道网关主要具有以下功能: 统一支付出口,提供丰富的支付工具原子能…

详解:扫雷游戏的实现

扫雷游戏的实现 设置雷排查雷标记雷打印棋盘初始化棋盘获得雷的个数扩展区域test.c的实现meni.c的实现meni.h的实现 铁汁们,今天给大家分享一篇扫雷游戏的实现,来吧,开造⛳️ 1.需要存储雷的信息,创建二维数组来存储数据信息&…

gateway的使用

什么是Spring Cloud Gateway 网关作为流量的入口,常用的功能包括路由转发,权限校验,限流等。 Spring Cloud Gateway 是Spring Cloud官方推出的第二代网关框架,定位于取代 Netflix Zuul。相比 Zuul 来说,Spring Cloud …

第1章 量化设计与分析基础

1.1 引言 如今一台价格不足500美元的手机,性能便堪比1993年世界上最快的售价5000万美元的计算机,这种快速发展既得益于计算机制造技术的发展,也得益于计算机设计的创新。 纵观计算机发展的历史,技术一直在稳定地提升&#xff0c…

【COT】Chain-of-Thought Prompting Elicits Reasoning in Large Language Models

文章目录 主要解决什么问题采用什么方法实验分析与结果消融实验Commonsense ReasoningSymbolic Reasoning 问题与展望 Chain-of-Thought Prompting Elicits Reasoning in Large Language Models 主要解决什么问题 大语言模型越来越大,效果越来越好。但是在一些具有…

Sui Builder House首尔站|主网上线后首次亮相

Sui Builder House提供与全球Sui构建者会面、合作并学习Sui平台前沿技术的机会。Sui基金会计划将于2023年在全球12个城市举办Sui Builder House。截止目前,已成功在美国丹佛市、越南胡志明市和中国香港举办三场Builder House活动。 Sui Builder House首尔站将于6月…

【D435i深度相机YOLO V5结合实现目标检测】

【D435i深度相机YOLO V5结合实现目标检测】 1. 前言2 分析2.1 关于yolo部分2.2 关于获取三维坐标的部分2.3 关于文件结构部分 3. 代码 1. 前言 参考:Ubutntu下使用realsense d435i(三):使用yolo v5测量目标物中心点三维坐标 欢迎大…

满意度指标- NPS 的ABtest(公式推导)

👉A 组的NPS如下 👉B 组的NPS如下 (下标为1,均为A组的样本数据;下标为2,均为B组的样本数据) 要验证A\B两组的NPS差异是否显著,可通过假设检验。 一、假设检验前置准备 1. 选择…

FastDFS理论与Java程序的对接(图片,文件上传)

目录 fastdfs概述Java程序对接fastDFSpom配置java启动类注解yaml文件配置controller类service类 fastdfs概述 什么是分布式文件系统? 是指文件系统管理的物理存储资源不一定直接连接在本地节点上,而是通过计算机与节点相连。 通俗来讲: 传统…

【medini analyze 软件介绍】

medini analyze 软件介绍 简介主要功能(功能安全范畴)1、HARA2、建模3、FMEA & FMEDA4、FTA*这里只是笔者根据汽车功能安全分析经验简单列举了medini analyze的部分功能,完整的功能介绍请参考ANSYS官网* 简介 medini analyze是一款专业的…

怎么把pdf文件压缩到最小?四招快速压缩!

怎么把pdf文件压缩到最小?平常我们要压缩一个文件,一般都知道该如何操作。系统中自带了压缩工具,只需右键点击需要压缩的对象并选择"压缩"选项即可完成操作。然而,很多人也会发现,尽管大部分的文件都可以通过…

网络安全公司Dragos披露网络安全事件

工业网络安全公司 Dragos 披露了它所称的“网络安全事件”,此前一个已知的网络犯罪团伙试图突破其防御并渗透到内部网络以加密设备。 虽然 Dragos 表示威胁行为者没有破坏其网络或网络安全平台,但他们可以访问公司的 SharePoint 云服务和合同管理系统。…