Pytorch网络模型训练

news2024/11/24 23:51:50

现有网络模型的使用与修改

vgg16_false = torchvision.models.vgg16(pretrained=False)        # 加载一个未预训练的模型
vgg16_true = torchvision.models.vgg16(pretrained=True)
# 把数据分为了1000个类别

print(vgg16_true)

以下是vgg16预训练模型的输出 

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

预训练模型的输出从1000类别转为10类别

import torchvision
from torch import nn
# 因为数据集过大,所以注释掉此行代码
# train_data = torchvision.datasets.ImageNet("./data_image_net", split='train', download=True,
#                                            transform=torchvision.transforms.ToTensor())

vgg16_false = torchvision.models.vgg16(pretrained=False)        # 加载一个未预训练的模型
vgg16_true = torchvision.models.vgg16(pretrained=True)
# 把数据分为了1000个类别

print(vgg16_true)

# vgg16_true.add_module("add_linear", nn.Linear(1000, 10))
vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))
# 在预训练模型的最后添加了一个新的全连接层,用于将最后的输出转化为10个类别
print(vgg16_true)

print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
# 未预训练模型的最后一层的输出特征数更改为了10
print(vgg16_false)

网络模型的保存与读取

加载未预训练的模型

vgg16 = torchvision.models.vgg16(pretrained=False)

方式一

# 保存方式1  保存的模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pyth")

#读取方式1
model = torch.load("vgg16_method1.pth")

方式二

# 保存方式2  不再保存模型结构,而是保存模型的参数为字典结构    推荐
torch.save(vgg16.state_dict(), "vgg16_method2.pyth")

# 方式2,加载模型
# model = torch.load("vgg16_method2.pth")     #这样输出的是字典类型
# print(model)
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))      # 将其恢复为网络模型
print(vgg16)

完整的模型训练套路

准备数据集

# 准备数据集
train_data = torchvision.datasets.CIFAR10("../data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(),
                                          download=True)

train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为{}".format(train_data_size))    # 50000
print("测试数据集的长度为{}".format(test_data_size))     # 10000

# 利用Dataloader来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

创建网络模型

# 创建网络模型  神经网络的代码在train_module文件
tudui = Tudui()

train_module文件

# 搭建神经网络
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        # 简化操作,并且按顺序进行操作
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

构建损失函数

# 损失函数
loss_fn = nn.CrossEntropyLoss()

构建优化器

# 优化器
# 如果学习率过大,模型可能会在最小值附近震荡而无法收敛;如果学习率过小,模型训练可能会过于缓慢
learning_rate = 0.01
# 使用随机梯度下降算法来更新模型的权重
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

设置训练集参数

# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10

添加tensorboard

# 将数据写入 TensorBoard 可视化的日志文件中
writer = SummaryWriter("../logs_train")

训练步骤

# tudui.train()
for data in train_dataloader:
    imgs, targets = data
    outputs = tudui(imgs)
    loss = loss_fn(outputs, targets)

    # 优化器优化模型
    optimizer.zero_grad()
    # 将优化器中的梯度缓存(如果有的话)清零
    loss.backward()
    # 计算损失函数(loss)相对于模型参数的梯度
    optimizer.step()

    total_train_step = total_train_step + 1
    if total_train_step % 100 == 0:
        # .item()是将tensor张量变为正常的数字
        print("训练次数:{},Loss:{}".format(total_train_step, loss.item()))
        # loss.item()是当前步骤的损失值
        writer.add_scalar("train_loss", loss.item(), total_train_step)
        # 使用add_scalar可以将一个标量添加到之前的所有标量值中,
        # 这样就可以在TensorBoard中绘制一个标量随时间变化的图表

测试步骤

# 测试步骤开始
# tudui.eval()
total_test_loss = 0
total_accuracy = 0
# 不会对以下的代码进行调优
with torch.no_grad():
    for data in test_dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        loss = loss_fn(outputs, targets)
        total_test_loss = total_test_loss + loss.item()
        # argmax(1)是横向看,argmax(0)是纵向看
        accuracy = (outputs.argmax(1) == targets).sum()
        # argmax在找到模型预测的最大概率对应的类别
        # 预测正确的个数
        total_accuracy = total_accuracy + accuracy

print("整体测试集上的Loss:{}".format(total_test_loss))
print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))
# 测试集上的总损失
writer.add_scalar("test_loss", total_test_loss, total_test_step)
writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
total_test_step = total_test_step + 1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1169923.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

apollo云实验:定速巡航场景仿真调试

定速巡航场景仿真调试 概述启动仿真环境仿真系统修改默认巡航速度 实验目的福利活动 主页传送门:📀 传送 概述 自动驾驶汽车在实现落地应用前,需要经历大量的道路测试来验证算法的可行性和系统的稳定性,但道路测试存在成本高昂、…

基于LDA主题+协同过滤+矩阵分解算法的智能电影推荐系统——机器学习算法应用(含python、JavaScript工程源码)+MovieLens数据集(三)

目录 前言总体设计系统整体结构图系统流程图 运行环境模块实现1. 数据爬取及处理2. 模型训练及保存1)协同过滤2)矩阵分解3)LDA主题模型 3. 接口实现1)流行电影推荐2)相邻用户推荐3)相似内容推荐 相关其它博…

submit使用share buffer传参问题及解决办法

有个程序,因为处理逻辑复杂,时间长,所以启用了job。 本来是直接把耗时的核心函数 in background。 但是这种情况有个麻烦,不方便检查报错信息,只能用SM58监控。 所以改为了open job submit program。因为可以sm37查看…

selenium自动化测试入门 —— 上传文件

selenium无法识别非web的控件,上传文件窗口为系统自带,无法识别窗口元素。 上传文件有两种场景:input控制上传和非input控件上传。 大多数情况都是input控件上传文件,只有非常少数的使用自定义的非input上传文件。 一、input控…

快手群控软件可同时控制多个账号!

随着社交媒体的普及,越来越多的人开始使用快手平台来分享自己的生活、展示才艺、开展商业活动等,然而,要想在快手上获得更多的关注和收益,需要投入大量的时间和精力,这时,快手群控软件应运而生,…

2023红帽论坛:探索数字化转型的下一步

作为全球领先的企业级开源解决方案提供商,红帽的一举一动在业界都备受关注。 近日,以“探索下一步”为主题,2023红帽论坛在北京盛大揭幕。 红帽全球副总裁兼大中华区总裁曹衡康在主题演讲中表示,智能创新并不是一个单纯依赖技术的…

d3dcompiler_47.dll是什么?d3dcompiler_47.dll缺失修复教程

d3dcompiler_47.dll丢失是一个比较常见的问题,尤其是在玩一些大型游戏或者使用一些图形处理软件时。这个问题可能会让你感到非常困扰,因为它可能会导致你的程序无法正常运行。但是,别担心,我在这里为你提供了四种解决方案&#xf…

【代码】【5 二叉树】d2

关键字: 先序遍历输出第k个元素、从右向左释放所有叶子结点、二叉树高度

硬件知识积累 CAN 总线接口

1. CAN 总线协议的介绍 控制器局域网总线(CAN,Controller Area Network)是一种用于实时应用的串行通讯协议总线,它可以使用双绞线来传输信号,是世界上应用最广泛的现场总线之一。CAN协议用于汽车中各种不同元件之间的通…

微信小程序上传图片和上传视频的组件失效

微信小程序上传图片和上传视频的组件失效 今天公司的小程序展示图片和视频文字的页面上传图片组件突然失效,之前用的好好的,突然所有使用都都发现用不了,以为是代码出现问题,反复查了很久。换了一个openid居然就可以了&#xff0…

学Python,一个月从小白到大神?看你怎么学!

Python是一门超强大而且超受欢迎的编程语言。它被用在各种领域,比如网站开发、数据分析、人工智能和机器学习。学会Python会给你创造很多职业机会,所以绝对是值得一试的。 但你有没有过这样的梦想:一个月时间,从Python小白变成Py…

修改docker 版本的mysql 8.0 本机Navicat 连不上的问题

1.进入容器 docker exec -it xxxx bash 2.使用root账号登录mysql mysql -u root -p 3.查看当前加密方式 use mysql; SELECT Host, User, plugin from user; 我这是改过了,应该都是caching_sha2_password 4. 修改加密方式 ALTER USER root% IDENTIFIED WITH m…

大数据(十):数据可视化(二)

专栏介绍 结合自身经验和内部资料总结的Python教程,每天3-5章,最短1个月就能全方位的完成Python的学习并进行实战开发,学完了定能成为大佬!加油吧!卷起来! 全部文章请访问专栏:《Python全栈教…

百度百科怎么创建?百科创建需要注意哪些(一文看懂品牌/企业/人物百科创建)

随着互联网的不断发展,许多企业或品牌都选择创建百度百科作为一种很好的展示方式。百度百科可以被视为一张网络名片,拥有它能够提高人物、企业、品牌的知名度和影响力。那么人物百科、企业百科、品牌百科到底怎么创建呢? 大家创建百科前建议先…

numpy中几种随机数生成函数的用法

一、np.random.rand() 该函数括号内的参数指定的是返回结果的形状,如果不指定,那么生成的是一个浮点型的数;如果指定一个数,那么生成的是一个numpy.ndarray类型的数组;如果指定两个数字,那么生成的是一个二…

日常踩坑-[sass]Error: Expected newline

在学习sass的时候,运行时发现报错 经过网上冲浪知道,原来在声明语言的时候 lang 不能声明为 sass ,而是 scss ,这就有点坑了 原因: scss是sass3引入进来的,scss语法有"{}“,”;"而sass没有,所以…

vue3+ts封装图标选择组件

概要 讲解在vue3的项目中封装一个简单好用的图标选择组件。 效果 第一步&#xff0c;准备图标数据 数据太多&#xff0c;大家去项目中看。项目地址https://gitee.com/nideweixiaonuannuande/xt-admin-vue3 第二步&#xff0c;页面与样式编写 <template><div>…

【蓝桥杯选拔赛真题48】python最小矩阵 青少年组蓝桥杯python 选拔赛STEMA比赛真题解析

目录 python最小矩阵 一、题目要求 1、编程实现 2、输入输出 二、算法分析

电压放大器可用于什么场合

电压放大器是电子器件中常见的一种放大器类型&#xff0c;它可以将输入信号的电压放大到更大的幅度&#xff0c;以满足特定应用的需求。电压放大器广泛应用于多个领域和场合&#xff0c;下面将详细介绍一些使用电压放大器的场景。 音频放大器&#xff1a;音频放大器是电压放大器…

Python 包管理器入门指南

什么是 PIP&#xff1f; PIP 是 Python 包管理器&#xff0c;用于管理 Python 包或模块。注意&#xff1a;如果您的 Python 版本是 3.4 或更高&#xff0c;PIP 已经默认安装了。 什么是包&#xff1f; 一个包包含了一个模块所需的所有文件。模块是您可以包含在项目中的 Pyth…