5_现有网络模型的使用

news2024/11/15 8:00:54

 教程:现有网络模型的使用及修改_哔哩哔哩_bilibili

官方网址:https://pytorch.org/vision/stable/models.html#classification

 初识网络模型

pytorch为我们提供了许多已经构造好的网络模型,我们只要将它们加载进来,就可以直接使用。以torchvision为例,关于神经网络处理图像的模型就分为好几个大类:如图像分类、目标检测、语义分割等等。如图所示:

 视频中的讲解以VGG模型为例,来向我们展示了网络模型的使用。

因为这个教学视频也已经是两三年前了的,现在和之前略微有所区别。在这里,简单做一个说明:比如说模型加载过程中参数的改变:

如今的模型中不再有pretrained参数,也就是如果大家需要下载模型的权重文件,需要自己手动下载。务必注意,写了会报错哦。 

权重文件的下载

 视频中有讲到模型的下载也是不大不小的,如果不进行设置,一般会默认下载在c盘,想要进行设置的话,可以在网上搜索有关代码:Pytorch预训练模型下载并加载(以VGG为例)自定义路径_怎么更改vgg下载路径-CSDN博客

但以上这位同学的方法我使用时出错,提示我没有这个属性:

model_zoo._download_url_to_file(url, os.path.join(dst_path, filename), hash_prefix, True)
AttributeError: module 'torch.utils.model_zoo' has no attribute '_download_url_to_file'

所以我略加修改,以下是我的处理下载过程,同样出错的同学可以看看:

from urllib.parse import urlparse
import torch
# import re
import os
def download_model(url, dst_path):
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    
    # HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
    # hash_prefix = HASH_REGEX.search(filename).group(1)
    
    torch.hub.download_url_to_file(url, os.path.join(dst_path, filename))
    return filename



path = "D:\\vscodeProjects\\models"
if not (os.path.exists(path)):
    os.makedirs(path)
url='https://download.pytorch.org/models/vgg16-397923af.pth'
download_model(url, path)

 只是这个下载的速度着实太慢,我先放弃了:

 关于这个权重文件的下载我犯了一点小迷糊。我有点搞不懂为什么费劲巴拉下载这么大个东西然后视频中又仅仅使用vgg16=torchvision.models.vgg16()这一句话就完事了。

于是我搜索了一下:

  • 在 PyTorch 中,许多流行的深度学习模型(如 VGG、ResNet、AlexNet 等)都有预先训练好的权重文件可供下载。这些权重文件包含了模型在大规模数据集(如 ImageNet)上训练的参数,可以帮助加快模型的收敛速度,提升模型的表现。下载预训练模型通常是为了避免从头开始训练模型,节省时间和计算资源。
  • torchvision.models 是 PyTorch 提供的一个模块,用于加载常见的计算机视觉模型,例如 VGG、ResNet、AlexNet 等。这些模型可以通过简单的调用来导入,并且可以选择加载预训练的权重。

 简而言之,权重文件可以简化我们模型的训练过程,我们可以通过使用权重文件来直接利用前辈的训练结果,稍作修改就可以变成我们自己的东西。

如果只是用vgg16=torchvision.models.vgg16()这么一句话来加载网络模型,得到的模型只有结构而没有经过训练的过程,因此它的权重是初始的。

网络模型的修改

因为官网中提到的VGG模型的官配数据集ImageNet实在是太大了(100+个G),笔记本实在带不了,所以还是使用我们之前已经用了很多次的数据集CIFAR10来搞,正好可以讲解一下怎样修改网络模型。

原官配数据集非常之大(对我一个初学者来说,是暂时见过最大的数据集了),最终一共分为1000个类。因此这个VGG模型最终输出为1000,为了适配于我们这个CIFAR10数据集(输出只有10类),我们为加载下来的VGG模型添加一个线性层,将原本的1000个类最终输出为10类。

from torch import nn
import torchvision
vgg16=torchvision.models.vgg16()
train_data=torchvision.datasets.CIFAR10("../dataset",train=True,transform=torchvision.transforms.ToTensor())
vgg16.add_module('add_linear',nn.Linear(1000,10))

print(vgg16)可以看到,最下面就是我们新添加的层:

 如果我们想添加在classifier这个模型中,我们也可以这样写:

vgg16.classifier.add_module('add_linear',nn.Linear(1000,10))

同样打印一下看效果:

 当然如果我们不想添加新的一层,我们也可以通过另外的一种方式来将输出从1000改为10:

如上图所示,已知最后一层是线性层,输入4096,输出1000,那么我们现在直接将最后一个线性层修改,输出改成10:

vgg16.classifier[6]=nn.Linear(in_features=4096,out_features=10,bias=True)

看结果:

模型的保存和加载

如果我们对网络模型进行了修改或者训练,如何将我们自己的模型保存下来呢?一共有以下两种方式:

vgg16=torchvision.models.vgg16()
vgg16.classifier[6]=nn.Linear(in_features=4096,out_features=10,bias=True)
#保存方式一:保存权重文件和模型结构
torch.save(vgg16,"vgg16_method1.pth")
#保存方式二(官方推荐),实际上保存的是权重文件,以字典方式存储:
torch.save(vgg16.state_dict(),"vgg16_method2.pth")

而如果我们想要取出我们已经保存的模型,就可以:

#方式一加载保存的模型
vgg16_method1=torch.load("vgg16_method1.pth")
#方式二加载保存的权重文件
vgg16_method2=torch.load("vgg16_method2.pth")
vgg16=torchvision.models.vgg16()
vgg16.load_state_dict(vgg16_method2)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1982335.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【CONDA】库冲突解决办法

如今,使用PYTHON作为开发语言时,或多或少都会使用到conda。安装Annaconda时一般都会选择在启动终端时进入conda的base环境。该操作,实际上是在~/.bashrc中添加如下脚本: # >>> conda initialize >>> # !! Cont…

python:YOLO格式数据集图片和标注信息查看器

作者:CSDN _养乐多_ 本文将介绍如何实现一个可视化图片和标签信息的查看器,代码使用python实现。点击下一张和上一张可以切换图片。 文章目录 一、脚本界面二、完整代码 一、脚本界面 界面如下图所示, 二、完整代码 使用代码时&#xff0…

无线WiFi破解原理(超详细)

大家应该都有过这样的经历,就是感觉自己家的无线网怎么感觉好像变慢了,"是不是有人蹭我家网?""还有的时候咱们出门也想试图蹭一下别人家的网",这里"蹭网"的前提是要破解对方的"无线密码"…

SQL注入复现1-18关

一、联合查询(1-4关) 首先打开第一关查看源代码,他的闭合方式为 找到闭合方式后,我们就可以使用order by来确定列数 我们可以看到使用order by 4--回车时报错,使用order by 3--时显示,所以我们就得到他得列…

微信丨QQ丨TIM防撤回工具

适用于 Windows 下 PC 版微信/QQ/TIM的防撤回补丁。支持最新版微信/QQ/TIM,其中微信能够选择安装多开功能。微信防撤回信息! 「防撤回」来自UC网盘分享https://drive.uc.cn/s/95f9aabbc9684

2024年起重机司机(限桥式起重机)证模拟考试题库及起重机司机(限桥式起重机)理论考试试题

题库来源:安全生产模拟考试一点通公众号小程序 2024年起重机司机(限桥式起重机)证模拟考试题库及起重机司机(限桥式起重机)理论考试试题是由安全生产模拟考试一点通提供,起重机司机(限桥式起重机)证模拟考试题库是根据起重机司机(限桥式起重机)最新版教…

elasticsearch教程

1. 单点部署(rpm): #提前关闭firewalld,否则无法组建集群 #1. 下载ES rpm包 ]# https://www.elastic.co/cn/downloads #2. 安装es ]# rpm -ivh elasticsearch-7.17.5-x86_64.rpm #3. 调整内核参数(太低的话es会启动报错) echo "vm.max_map_count655360 fs.file-max 655…

MySQL1 DDL语言

安装与配置 官网: MySQL :: Download MySQL Installer 阿里云: MySQL8 https://www.alipan.com/s/auhN4pTqpRp 点击链接保存,或者复制本段内容,打开「阿里云盘」APP ,无需下载极速在线查看,视频原画倍速…

外卖项目day14(day11)---数据统计

Apache ECharts 大家可以看我这篇文章: Apache ECharts-CSDN博客 营业额统计 产品原型 接口设计 新建admin/ReportController /*** 数据统计相关接口*/ RestController RequestMapping("/admin/report") Api(tags "数据统计相关接口") Slf…

快速解密哈希算法利器Hasher:解密MD5、SHA256、SHA512、RIPEMD160等最佳工具

文章目录 一、工具概述1.1主要功能点1.2 支持多种哈希算法 二、安装方法三、使用教程四、结语 一、工具概述 Hasher 是一个哈希破解工具,支持多达 7 种类型的哈希算法,包括 MD4、MD5、SHA1、SHA224、SHA256、SHA384、SHA512 等。它具有自动检测哈希类型、支持 Windows 和 Linux…

浙大阿里联合开源AudioLCM,在通用音频合成领域实现潜在一致性模型的新突破...

文本到通用音频生成(Text-to-Audio Generation,简称 TTA)作为生成任务的一个子领域,涵盖了音效创作、音乐创作和合成语音,具有广泛的应用潜力。在此前的神经 TTA 模型中,潜在扩散模型(Latent Di…

【RHEL7】无人值守安装系统

目录 一、kickstart服务 1.下载kickstart 2.启动图形制作工具 3.选择设置 4.查看生成的文件 5.修改ks.cfg文件 二、HTTP服务 1.下载HTTP服务 2.启动HTTP服务 3.将挂载文件和ks.cfg放在HTTP默认目录下 4.测试HTTP服务 三、PXE 1.查看pxe需要安装什么 2.安装 四、…

批量按照原图片名排序修改图片格式为00000001.png(附代码)

💪 专业从事且热爱图像处理,图像处理专栏更新如下👇: 📝《图像去噪》 📝《超分辨率重建》 📝《语义分割》 📝《风格迁移》 📝《目标检测》 📝《暗光增强》 &a…

ARMxy工控机使用Node-Red教程:安装工具和依赖(2)

2.3 工具安装 Node-Red 安装过程需要用到网络。请通过网线将设备千兆网口 ETH1 连接至互联网,确保可正常访问互联网。 Node-Red 是一个基于Node的可视化编程工具,因此需要先安装Node。为了便于测试,我司提供的 node-v16.14.0-linux-arm64.t…

原神升级计划数据表:4个倒计时可以修改提示信息和时间,可以点击等级、命座、天赋、备注进行修改。

<!DOCTYPE html> <html lang"zh-CN"><head><meta charset"UTF-8"><title>原神倒计时</title><style>* {margin: 0;padding: 0;box-sizing: border-box;body {background: #0b1b2c;}}header {width: 100vw;heigh…

「队列」实现FIFO队列(先进先出队列|queue)的功能 / 手撕数据结构(C++)

概述 队列&#xff0c;是一种基本的数据结构&#xff0c;也是一种数据适配器。它在底层上以链表方法实现。 队列的显著特点是他的添加元素与删除元素操作&#xff1a;先加入的元素总是被先弹出。 一个队列应该应该是这样的&#xff1a; --------------QUEUE-------------——…

大数据资源平台建设可行性研究方案(58页PPT)

方案介绍: 在当今信息化高速发展的时代&#xff0c;大数据已成为推动各行各业创新与转型的关键力量。为了充分利用大数据的潜在价值&#xff0c;构建一个高效、安全、可扩展的大数据资源平台显得尤为重要。通过本方案的实施企业可以显著提升数据处理能力、优化资源配置、促进业…

SQL注入实例(sqli-labs/less-8)

0、初始页面 1、确定闭合字符 ?id1 and 11 ?id1 and 12 ?id1 ?id1 and 11 -- ?id1 and 12 -- 确定闭合字符为单引号&#xff0c;并且正确页面与错误页面的显示不同 2、爆库名 使用python脚本 def inject_database1(url):name for i in range(1, 20):low 32high 1…

【大模型从入门到精通5】openAI API高级内容审核-1

这里写目录标题 高级内容审核利用 OpenAI 内容审核 API 的高级内容审核技术整合与实施使用自定义规则增强审核综合示例防止提示注入的策略使用分隔符隔离命令理解分隔符使用分隔符实现命令隔离 高级内容审核 利用 OpenAI 内容审核 API 的高级内容审核技术 OpenAI 内容审核 AP…

SQL注入漏洞复现1

一、靶场信息 sqli-labs下载&#xff1a;https://github.com/Audi-1/sqli-labs phpstudy下载地址&#xff1a;http://down.php.cn/PhpStudy20180211.zip 我是在本地安装小皮搭建环境&#xff0c;相比于在服务器上搭建环境&#xff0c;更加简单 二、注入实操 Less-1 爆库名…