PyTorch(六)网络模型

news2025/1/10 20:21:34

文章目录

    • Log
  • 一、现有网络模型的使用及修改
    • 1. VGG
      • ① ImageNet 数据集下载
      • ② 模型加载
      • ③ 模型改造
        • a. 添加一个线性层
        • b. 插入一个线性层
        • c. 修改一个线性层
  • 二、网络模型的保存与读取
    • ① 网络模型的保存
      • a. 保存方式一
      • b. 保存方式二
    • ② 网络模型的读取
      • a. 读取方式一
      • b. 读取方式二
  • 三、完整的模型训练套路
  • 四、利用 GPU 训练
    • 1. 方法一
    • 2. 方法二
  • 五、完整的模型验证套路
  • 总结


Log

2022.12.11有点感受,不好描述,但是似乎是体会到了生活中的一点东西,充满激情,充满希望,勇敢面对,所以干脆今天就把剩下的部分学完得了(满口胡话)
2022.12.12剩了一点,今天来解决掉


一、现有网络模型的使用及修改

  • 语音相关的模型在 torchaudio 中
  • 文字相关的模型在 torchtext 中
  • 图像相关的模型在 torchvision 中
    • 分类:Classification
    • 语义分割:Semantic Segmentation
    • 目标检测,实例分割和人物关键点检测:Object Detection, Instance Segmentation and Person Keypoint Detection
    • 视频分类:Video Classification

1. VGG

  • torchvision 中最常用的分类模型 VGG 有 VGG11、VGG13、VGG16、VGG19,其中最常用的是 VGG16 和 VGG19。

① ImageNet 数据集下载

  • 想要使用 torchvision.datasets.ImageNet 数据集需要提前安装 scipy
  • 试图下载该数据集:
train_data = torchvision.datasets.ImageNet("../dataset/ImageNet", split='train', download=True,
                                           transform=torchvision.transforms.ToTensor())
  • 该数据集已经无法公开访问,需要手动下载,并且大小100G+,所以直接放弃。

② 模型加载

  • 预训练参数为 False 时,加载的是初始的参数,为 True 时则是训练好的能够达到较好效果的参数。
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)
  • 输出查看网络结构:
print(vgg16_true)

在这里插入图片描述在这里插入图片描述

  • 可以看到最后的的输出 out_features=1000,即最后可以识别 1000 个分类,在 ImageNet ILSVRC2012 中也可以看到对应共有 1000 个类别:
    在这里插入图片描述

③ 模型改造

  • 现在我们使用上面加载好的模型对 C I F A R 10 \rm CIFAR10 CIFAR10 进行分类,下面是三种不同的方法:

a. 添加一个线性层

  • 我们可以在原网络的基础上再添加一个线性层:输入 1000 个特征,输出 10 个特征
vgg16_true.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)
  • 通过观察输出的网络结构可以发现我们在网络的最后添加了一个线性层 add_linear(与 classifier 在同一级上):
    在这里插入图片描述

b. 插入一个线性层

  • 对上面的代码稍作修改,就可以将新加入的线性层插入到原有的结构中:
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)
  • 输出的网络结构如下( add_linear 包含在 classifier 中):
    在这里插入图片描述

c. 修改一个线性层

  • 直接对原有的模型进行修改,将输出特征数改为 10:
print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)
  • 原有结构如下:
    在这里插入图片描述

  • 修改的结构如下:
    在这里插入图片描述

二、网络模型的保存与读取

① 网络模型的保存

  • 首先加载未经训练的模型:
vgg16 = torchvision.models.vgg16(pretrained=False)

a. 保存方式一

torch.save(vgg16, "../models/VGG/vgg16_method1.pth")
  • 两个参数分别是要保存的模型和保存路径,保存的模型文件一般是 .pth 格式的。
  • 该方法不仅保存了网络模型的结构,也保存了网络模型的参数。

b. 保存方式二

torch.save(vgg16.state_dict(), "../models/VGG/vgg16_method2.pth")
  • 该方法不再保存网络的结构,而是以字典的形式保存网络模型的参数,对应上面的第一个参数,第二个参数则是保存的路径(该方法下模型文件的大小要比方式一稍小一些)

② 网络模型的读取

a. 读取方式一

  • 对应上面的保存方式一
model = torch.load("../models/VGG/vgg16_method1.pth")
  • 如果保存的是自定义的网络,那么在读取时必须要导入原有的网络定义的类才能够成功加载。

b. 读取方式二

  • 对应上面的保存方式二
  • 由于只保存了参数,所以我们需要重建网络模型结构
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("../models/VGG/vgg16_method2.pth"))

三、完整的模型训练套路

  • 加载数据集:
train_data = torchvision.datasets.CIFAR10(root="../dataset/CIFAR10", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="../dataset/CIFAR10", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)
  • 获取数据集长度:
train_data_size = len(train_data)
test_data_size = len(test_data)
  • 利用 DataLoader 来加载数据集:
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
  • 搭建神经网络:
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x
        
mo = MyModel()
  • 定义损失函数:
loss_fn = nn.CrossEntropyLoss()
  • 定义优化器:
learning_rate = 1e-2
optimizer = torch.optim.SGD(mo.parameters(), lr=learning_rate)
  • 设置训练网络的参数:
# 训练的次数
total_train_step = 0
# 测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10
  • 训练开始:
for i in range(epoch):
    print("-------第 {} 轮训练开始-------".format(i+1))

    # 训练步骤开始
    mo.train()
    for data in train_dataloader:
        imgs, targets = data
        outputs = mo(imgs)
        loss = loss_fn(outputs, targets)

        # 优化器优化模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step = total_train_step + 1
        if total_train_step % 100 == 0:
            print("训练次数:{}, Loss: {}".format(total_train_step, loss.item()))

    # 测试步骤开始
    mo.eval()
    total_test_loss = 0
    total_accuracy = 0
    # 此处没有用到梯度
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            outputs = mo(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss.item()
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy = total_accuracy + accuracy

    print("整体测试集上的Loss: {}".format(total_test_loss))
    print("整体测试集上的正确率: {}".format(total_accuracy/test_data_size))
    total_test_step = total_test_step + 1

    torch.save(mo, "model_{}.pth".format(i))
    print("模型已保存")
  • 其中在计算正确率的时候用到了 argmax 函数。
  • 使用 tensorboard:
writer = SummaryWriter("logs_train")

for i in range(epoch):
    print("-------第 {} 轮训练开始-------".format(i+1))

    # 训练步骤开始
    mo.train()
    for data in train_dataloader:
        imgs, targets = data
        outputs = mo(imgs)
        loss = loss_fn(outputs, targets)

        # 优化器优化模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step = total_train_step + 1
        if total_train_step % 100 == 0:
            print("训练次数:{}, Loss: {}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # 测试步骤开始
    mo.eval()
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            outputs = mo(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss.item()
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy = total_accuracy + accuracy

    print("整体测试集上的Loss: {}".format(total_test_loss))
    print("整体测试集上的正确率: {}".format(total_accuracy/test_data_size))
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
    total_test_step = total_test_step + 1

    torch.save(mo, "model_{}.pth".format(i))
    print("模型已保存")

writer.close()
  • 记录结果如下:
    在这里插入图片描述

四、利用 GPU 训练

  • 咱的电脑是 AMD 的,所以这部分就简单记录一下。

1. 方法一

  • 在上面定义的过程后面添加对应的方法即可:
# 模型定义
mo = MyModel()
mo.cuda()
# 损失函数
loss_fn = nn.CrossEntropyLoss()
loss_fn.cuda()
# 训练和测试过程的中
imgs, targets = data
imgs = imgs.cuda()
targets = targets.cuda()
  • 在使用 cuda 方法之前最好加入以下代码进行判断:
if torch.cuda.is_available():

2. 方法二

  • 定义设备并调用 to 方法:
# 定义训练的设备
device = torch.device("cuda")
# 模型定义
mo = MyModel()
mo.to(device)
# 损失函数
loss_fn = nn.CrossEntropyLoss()
loss_fn.to(device)
# 训练和测试过程的中
imgs, targets = data
imgs = imgs.to(device)
targets = targets.to(device)
  • 如果有多个显卡的话定义设备时可以采用以下方法:
device = torch.device("cuda:0")
device = torch.device("cuda:1")

五、完整的模型验证套路

  • 利用训练好的模型进行测试。
import torch
import torchvision
from PIL import Image
from torch import nn

image_path = "../images/dog.jpg"
image = Image.open(image_path)
print(image)
image = image.convert('RGB')
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
                                            torchvision.transforms.ToTensor()])

image = transform(image)
print(image.shape)

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x
# 想在 CPU 上运行 GPU 跑出来的模型时添加第二个参数
model = torch.load("tudui_29_gpu.pth", map_location=torch.device('cpu'))
print(model)
image = torch.reshape(image, (1, 3, 32, 32))
model.eval()
with torch.no_grad():
    output = model(image)
print(output)

print(output.argmax(1))


总结

  • 本篇文章介绍了如何对已有的模型进行修改或者添加自己想要的结构,保存的读取网络模型的方法,利用 GPU 进行训练,以及完整的模型训练和验证的套路。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/82947.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

微信公众号服务号配置对接在线客服系统教程

如果只需要实现微信公众号的关注自动回复,关键词自动回复功能,普通订阅号就可以 当需要对接实现公众号的模板消息提醒,模板消息与客服端H5的对接,访客在微信点击或扫码时获取到微信的昵称头像,需要网页授权功能。这种是…

Spring(一):Spring核心与设计思想(IoC、DI)

目录一、Spring是什么1.1 容器是什么?1.2 什么是IoC?1.3 理解Spring IoC1.4 DI是什么一、Spring是什么 我们这里所说的Spring指的是SpringFrameWork,是一个开源框架。Spring支持广泛的应用场景,它可以让Java企业级的应用程序开发…

k8s编程operator实战之云编码平台——②controller初步实现

文章目录1、工作空间镜像制作2、controller实现2.1 使用kubebuilder创建工程2.2 代码实现2.2.1 引入grpc2.2.2 实现CloudIdeServiceStatusInformer的实现CloudSpaceService定义方法CreateSpaceAndWaitForRunning方法GetPodSpaceInfo方法DeleteSpace方法GetPodSpaceStatus2.2.3 …

人才盘点的工具与方法有哪些?怎样做好人才盘点?

人才盘点是对组织和人才进行系统管理的一种流程。在此过程中,对组织架构、人员配比、人才绩效、关键岗位的继任计划、关键人才发展、关键岗位的招聘及关键人才的晋升和激励进行深入讨论,并制定详细的组织行动计划,确保组织以更加优化的结构和…

非零基础自学计算机操作系统 第1章 操作系统概述 1.5 操作系统的硬件环境 1.5.1 定时装置 1.5.2 堆与栈 1.5.3 寄存器

非零基础自学计算机操作系统 文章目录非零基础自学计算机操作系统第1章 操作系统概述1.5 操作系统的硬件环境1.5.1 定时装置1.5.2 堆与栈1.5.3 寄存器第1章 操作系统概述 1.5 操作系统的硬件环境 构建一个高效、可靠的操作系统,硬件需要提供哪些支持? 1…

MySQL数据库基本使用(一)-------登录及查看基本信息

1.MySQL登录命令 格式如下: mysql -h 主机名 -P 端口号 -u 用户名 -p密码例如: mysql -h localhost -P 3306 -u root -pabc123 # 这里我设置的root用户的密码是abc123注意: -p与密码之间不能有空格,其他参数名与参数值之间可以…

Spring Boot启动原理源码

Spring Boot启动原理源码 注意:这个springboot启动源码和springboot自动配置原理的源码是十分重要的,面试的时候要是问springboot,一般都会问这两个。 源码: SpringBoot 事假监听器发布顺序: 1.ApplicationStartingEvent 2.ApplicationEnvironmentPrepa…

springboot+mybatis配置多数据源实战

1.背景说明 2.配置多数据源步骤 2.1 项目结构变更 2.2 添加配置类 2.3 修改配置文件数据连接配置信息 1.背景说明一般一个项目中只会连接一个数据库.但是随着需求变更,会要求同一个项目中连接多个数据库,本文就讲一下如何在一个项目中对多…

usaco training刷怪旅 第一层第二题 Greedy Gift Givers

usaco training 关注我持续创作training题解 翻译有点奇葩,我就上原题目了,各位自己翻译吧QwQ A group of NP (2 ≤ NP ≤ 10) uniquely named friends has decided to exchange gifts of money. Each of these friends might or might not give some m…

一种基于PCI总线的反射内存卡设计

一种基于PCI总线的反射内存卡设计 摘要: 对实时传输, 传统的以太网络由于传输协议开销的不确定性, 很难满足实时网络的要求, 实时网络是一种应用于高实时性要求的专用网络通信技术, 一般采用基于高速网络的共享存储器…

Python爬虫实战,requests+openpyxl模块,爬取小说数据并保存txt文档(附源码)

前言 今天给大家介绍的是Python爬取小说数据并保存txt文档,在这里给需要的小伙伴们代码,并且给出一点小心得。 首先是爬取之前应该尽可能伪装成浏览器而不被识别出来是爬虫,基本的是加请求头,但是这样的纯文本数据爬取的人会很多…

web网页设计与开发:基于HTML+CSS+JavaScript简单的个人博客网页制作期末作业

🎉精彩专栏推荐👇🏻👇🏻👇🏻 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 💂 作者主页: 【主页——🚀获取更多优质源码】 🎓 web前端期末大作业…

怎么让文字转换成语音?一步一步让你学会

在日常的生活中,我们经常会需要将文字转换成语音的情况,例如广告词、给文本配音等等,当然我们就简单的方法就是自己手动进行配音,但是如果没有专业的设备和配音环境,是很难配出很好的效果的,这该怎么办呢&a…

飞链云智能机器人-基于ChatGPT的有趣问答

最近ChatGPT火起来了; 可玩性很高,不亚于之前AI绘画的视觉冲击;这次ChatGPT带来的是逻辑冲击;上下文逻辑远超现有市面上其他所有的AI对话机器人; 有人用技巧训练ChatGPT,ChatGPT机器人宣言要毁灭人类&…

备战2023蓝桥国赛-传纸条

题目描述: 解析: 这道题想了我好久,一开始我是想假如只走一条路线,从(1,1)走到(m,n),这种问题该怎么解决呢?针对这种问题我是设了dp[k][i][j]表示走了k步到达(i,j)的好心程度之和的…

[附源码]JAVA毕业设计迎宾酒店管理系统录屏(系统+LW)

[附源码]JAVA毕业设计迎宾酒店管理系统录屏(系统LW) 项目运行 环境项配置: Jdk1.8 Tomcat8.5 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目…

R语言中使用多重聚合预测算法(MAPA)进行时间序列分析

最近我们被客户要求撰写关于时间序列分析的研究报告,包括一些图形和统计输出。这是一个简短的演示,可以使用该代码进行操作。使用MAPA生成预测。 > mapasimple(admissions)t1 t2 t3 t4 t5 t6 t7 t8 t9 t…

ElasticsearchRestTemplate 和ElasticsearchRepository 的使用

操作ElasticSearch的数据,有两种方式一种是 ElasticsearchRepository 接口,另一种是ElasticsearchTemplate接口 SpringData对ES的封装ElasticsearchRestTemplate类,可直接使用,此类在ElasticsearchRestTemplate基础上进行性一定程…

Kibana:使用 Maps 来显示分布式的团队

在我之前的文章 “Kibana:如何在 Maps 应用中显示图片提示” 里,我展示了如何在 Kibana 中使用图片来展示一个图片的提示。这个在很多情况下是非常有用的,比如在疫情发生期间,我可以通过点击地图上的点来查看发生疫情人员的详细情…

ADI Blackfin DSP处理器-BF533的开发详解40:图像处理专题-GrayStretch 图像的灰度拉伸(含源码)

硬件准备 ADSP-EDU-BF533:BF533开发板 AD-HP530ICE:ADI DSP仿真器 软件准备 Visual DSP软件 硬件链接 功能介绍 代码实现了图像的灰度拉伸,代码运行时,会通过文件系统打开工程文件根目下" …/ImageView"路径中的 t…