1_1torch学习

news2024/12/23 13:48:11

一、torch基础知识

1、torch安装

pytorch cuda版本下载地址:https://download.pytorch.org/whl/torch_stable.html

 其中先看官网安装torch需要的cuda版本,之后安装cuda版本,之后采用pip 下载对应的torch的gpu版本whl来进行安装。使用pip安装时如果是conda需要切换到对应的env下。

2.torch.nn.functional中有很多功能。什么时候用nn.Module,什么时候用nn.functional。一般参数情况下有学习参数使用nn.Module,其他情况用nn.functional相对更简单一些。

3、一般模型在训练时会使用model.train,这样会正常使用Batch Normalization和Dropout,

     测试时一般选择model.eval(),这样就不使用Batch Normalization和Dropout

4、对于tensor,维度0表示纵轴,维度1表示横轴。

5、nn.Sequential是表示按照序列进行层运算。

6、加载预训练模型,torchvision中有很多经典网络架构,调用起来十分方便,并且可用人家训练好的权重参数来继续训练,也就是所谓的迁移学习。

需要注意的是别人训练好的任务跟咱们得可不是完全一样,需要把最后head层改一改,一般也就是最后的全连接层,改成咱们自己的任务;

训练时可以全部重头训练,也可以只训练最后任务层;

网络保存可以有选择性,选择验证集中效果最好的。   

 7、from torchvision    import transforms, models, datasets

8、transforms.Compose([

transforms.Resize([96,96]),

transforms.RandomRotation(45),

transforms.RandomRotation(45),#随机旋转,-45到45

transforms.CenterCrop(64),

transforms.HorizontalFilp,RandomVerticalFlip,ColorJitter

ToTensor, Normalize

]),

归一化:x减u除以标准差

9、model name

feature_extract = True都用人家的特征,先不更新。

model_ft = models.resnet18()

最后AdaptiveAvgPool2d(output_size=(1,1))

in_features=512, out_features=1000,bias=True

def set_parameter_requires_grad(model, feature_extracting)

        if feature_extracting:

                for param in model.parameters()#name, param in model.named_parameters()

                        param.requires_grad = False

model_ft = model.resnet18(pretrained=use_pretrained)

num_f = model_ft.fc.in_features

model_ft.fc = nn.Linear(num_f, 102)

model_ft, input_size=initialize_model(model_name,102,feature_extract, use_pretrained=True)

#保存模型就是保存graph、parameter

filename='model.ft'
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t",name)
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)#将需要更新的参数传进来,这里只更新最后的fc层
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1
#学习率每迭代10个epoch衰减原来的1/10
criterion = nn.CrossEntropyLoss()

def train_model(model, dataloaders, criterion, optimizer, num_epoch=25,                      
    filename="best.pt")
    best_acc = 0 #模型保存最好的
    device = 'cuda:0'
    LRs = [optimizer.param_groups[0]['lr']]
    best_model_wts = copy.deepcopy(model.state_dict())
    for inputs, labels in dataloader[phase]:
        optimizer.zero_grad()
        #只有训练的时候计算和更新梯度
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        _, preds = torch.max(outputs, 1)
        if phase == 'train':
            loss.backward()
            optimizer.step()
        runing_loss += loss.item()
        runing_crrets += torch.sum(preds == labels.data)
    epoch_loss = running_loss/len(dataloader[phase].dataset)
       if phase == 'valid' and epoch_acc>best_acc:
            state = {
                'state_dict' :model.state_dict(),
                'best_acc': best_acc,
                'optimizer':optimizer.state_dict(),
            }
        LRs.append(optimizer.param_groups[0]['lr'])
        scheduler.step()
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history


    

10、训练对比

resnet18,只冻住FC层,则性能36%

resnet18,全部训练,则性能

for param in model_ft.parameters():

        param.requires_grad = True

checkpoint = torch.load(filename)

model_ft.load_state_dict(checkpoint['state_dict'])

测试数据预处理:

测试数据处理方法要跟训练时一致才可以

crop操作目的保证输入大小一致

标准化也是必须得,使用训练相同的mean和std

最后颜色通道是一个维度,很多工具包都不一样,需要转换

PIL工具包,from PIL import image

fig =  plt.figure(figsize=(20,20))

11、数据集制作

 

 

 

 

 

 

 

 

(6) 将写好的Dataset进行实例化,并实例化dataloader

 

 

 

五、工具使用jupyter

1、获取jupyter 对应kernel

import sys
sys.executable

2、为jupyter添加新的kernel,将conda通过source activate fuyao1切换到对应的环境,之后

python -m ipykernel install --user --name fuyao1

3、在jupyter中使用from argparse import SUPPRESS, ArgumentParser

args = cls._parser.parse_args(args=[])  需要在括号中添加args=[]

        cls._model_parser.add_argument(
            "--save",
            type=str,
            default="/tmp/share",
            help="save path for model or inference results, not required for resuming " "experiment.",
        )   在对应位置添加默认内容default即可。

                            

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/503327.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux内核中的链表(list_head)使用分析

【摘要】本文分析了linux内核中的list_head数据结构的底层实现及其相关的各种调用源码,有助于理解内核中链表对象的使用。 二、内核中的队列/链表对象 在内核中存在4种不同类型的列表数据结构: singly-linked listssingly-linked tail queuesdoubly-lin…

SSM框架学习-bean生命周期理解

Spring启动,查找并加载需要被Spring管理的Bean,进行Bean的实例化(反射机制);利用依赖注入完成 Bean 中所有属性值的配置注入; 第一类Aware接口: 如果 Bean 实现了 BeanNameAware 接口的话&#…

Yolov8改进---注意力机制:CoTAttention,效果秒杀CBAM、SE

1.CoTAttention 论文:https://arxiv.org/pdf/2107.12292.pdf CoTAttention网络是一种用于多模态场景下的视觉问答(Visual Question Answering,VQA)任务的神经网络模型。它是在经典的注意力机制(Attention Mechanism)上进行了改进,能够自适应地对不同的视觉和语言输入进…

day28_mysql

今日内容 零、 复习昨日 一、函数[了解,会用] 二、事务[重点,理解,面试] 三、索引[重点,理解,面试] 四、存储引擎 五、数据库范式 六、其他 零、 复习昨日 见晨考 一、函数 字符串函数数学函数日期函数日期-字符串转换函数流程函数 1.1 字符串函数 函数解释CHARSET(str)返回字…

一个简单的watch以及ESModule导入和解构的区别

背景 最近写了个很有意思的方法,感觉还蛮不错的就分享一下。起先是我在写calss组件的时候遇到一个问题,我需要监听一个导入的值,导入的值最开始是undefined,经过异步操作以后会得到一个新的值,而我需要在这个class组件…

[echarts] legend icon 自定义的几种方式

echarts 官方配置项 地址 一、默认 图例项的 icon circle, rect, roundRect, triangle, diamond, pin, arrow, none legend: {top: 5%,left: center,itemWidth: 20,itemHeight: 20,data: [{icon: circle, name: 搜索引擎},{icon: rect, name: 直接访问},{icon: roundRect, n…

分布式系统---MapReduce实现(Go语言)

一、说明 本次实验是基于MIT-6.824的课程,详情请参见官网主页下载源代码 二、MapReduce原理 2.1 经典的分布式模型 MapReduce是经典的分布式模型。通过Map函数和Reduce函数实现。 分布式计算,就是利用多台机器,完成一个任务。关于分布式…

算法第一天力扣---2651. 计算列车到站时间

1.题目要求: 给你一个正整数 arrivalTime 表示列车正点到站的时间(单位:小时),另给你一个正整数 delayedTime 表示列车延误的小时数。 返回列车实际到站的时间。 注意,该问题中的时间采用 24 小时制。 示…

让ChatGPT猜你喜欢——ChatGPT后面的推荐系统

Chat GPT的大热,让人们的视线又一次聚焦于“人工智能”领域。通过与用户持续对话的形式,更加丰富的数据会不断滚动“雪球”,让Chat GPT的回答变得越来越智能,越来越接近用户最想要的答案。ChatGPT能否颠覆当下的推荐系统范式&…

第三章 灰度变换与空间滤波

第三章 灰度变换与空间滤波 3.1背景知识 ​ 空间域指图像平面本身。变换域的图像处理首先把一幅图像变换到变换域,在变换域中进行处理,然后通过反变换把处理结果返回到空间域。空间域处理主要分为灰度变换与空间滤波。 3.1.1 灰度变换和空间滤波基础 …

cmcc_simplerop

1,三连 2,IDA分析 溢出点: 偏移:0x144(错误) 这里动态重新测试了一下偏移: 正确偏移:0x20 3,找ROP 思路: 1、找系统调用号 2、ROPgadget找寄存器 3、写入/bin/sh ROPgadget --binary simpler…

7-2使用Redis构建任务队列

目录 7-2使用Redis构建任务队列 第1关:先进先出任务队列 1、rpush/lpush命令:rpush(name,values[values…]) 2、blpop:blpop(keys, timeout)和 lpop/rpop:lpop(name) 删并返回删除值 3、lpushx/rpushx:lpushx(name…

使用CKKS全同态求近似倒数(近似乘法逆元)

求倒数的算法 两个数互为倒数,是说这两个数乘起来等1.比如a和b互为倒数,那么ab1. 5的倒数是0.2,我们可以很简单的求出来,但是如何在密文域中求一个数的倒数呢? 文章《An investigation of complex operations with …

C#自适应布局

注意事项:不要在Form1中添加任何布局,页面背景不设置图片 步骤: 1、在项目中添加AutoWindowsSize.cs类,内容如下: using System; using System.Collections.Generic; using System.ComponentModel; using System.Da…

2.2 掌握 NumPy 矩阵与通用函数

2.2 掌握 NumPy 矩阵与通用函数 2.2.1 创建NumPy矩阵创建NumPy矩阵矩阵的运算矩阵的属性 2.2.2 掌握ufunc函数1、常用的ufunc函数运算2、ufunc函数的广播机制 2.2.1 创建NumPy矩阵 创建NumPy矩阵 1、使用mat函数创建矩阵: matr1 np.mat(“1 2 3;4 5 6;7 8 9”) 2…

casbin轻量级的基于配置的授权框架

简介 Casbin是一个强大的、高效的开源访问控制框架,其权限管理机制支持多种访问控制模型。 Casbin提供了一个执行者 根据提供给执行者的策略和模型文件验证传入的请求。再根据对应的配置授权策略,验证请求判断释放那些行动。 在 Casbin 中, 访问控制模…

由于找不到vcomp140.dll无法继续执行代码,解决方法全攻略

如何解决找不到vcomp140.dll错误?在使用某些软件或者游戏的时候,你可能会遇到下面的错误提示:“由于找不到vcomp140.dll,无法继续执行代码”。这个错误提示通常表示你的电脑缺少一个或多个DLL文件,而这些文件是软件和游…

「字节跳动测试开发面经」一二三面+hr面+超级全资料+复习资料

​ 说在前面,面试时最好不要虚报工资。本来字节跳动是很想去的,几轮面试也通过了,最后没offer,自己只想到几个原因: 1、虚报工资,比实际高30%; 2、有更好的人选,这个可能性不大&am…

【Linux】软件包管理器 yum和编辑器-vim的基本使用

文章目录 一、yum背景知识1.商业生态2.开源生态3.Linux软件生态本土化 二、yum的基本使用1.什么是软件包2.查看软件包3.安装软件4.卸载软件5.rzsz 三、vim的基本使用1.vim的基本概念2.vim的基本操作3.vim命令模式命令集4.vim末(底)行模式命令集5.操作总结 四、简单vim配置1.vim…

C++学习day--10 条件判断、分支

1、if语句 if 语句的三种形态 形态1&#xff1a;如果。。。那么。。。 #include <iostream> using namespace std; int main( void ) { int salary; cout << " 你月薪多少 ?" ; cin >> salary; if (salary < 20000) { cout <&…