完整模型的训练套路

news2024/9/20 17:22:21

从心所欲

不逾矩

天大地大

皆可去

一、官方模型的初使用

使用VGG16模型

 VGG模型使用代码示例:

import torchvision.models
from torch import nn

dataset = torchvision.datasets.CIFAR10('/cifar10', False, transform=torchvision.transforms.ToTensor())

vgg16_true = torchvision.models.vgg16(pretrained=True)
vgg16_false = torchvision.models.vgg16(pretrained=False)
print(vgg16_false)

# 改造VGG,增加一层
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)

# 改造vgg,修改一层
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)

说明:

  1. pretrained=True:当设置为True时,模型将加载在大规模图像数据集(如ImageNet)上预训练的权重。这些预训练的权重经过了在大量图像上的训练,可以捕捉到通用的图像特征。通过加载预训练权重,可以将VGG模型初始化为在ImageNet上训练得到的状态,并且这些权重可以作为初始参数用于特定任务的微调或迁移学习。

  2. pretrained=False:当设置为False时,模型将使用随机初始化的权重。这意味着模型的权重没有经过预训练,需要从头开始进行训练。在这种情况下,模型将不会具备捕捉通用图像特征的能力,而是需要根据特定任务的数据进行训练。

pretrained=Truepretrained=False区别在于是否加载预训练的权重。如果你想要在特定任务上使用VGG模型,并且你的任务与图像分类或特征提取相关,那么通常建议使用pretrained=True,以便利用预训练权重的优势。如果你的任务与图像分类或特征提取无关,或者你希望从头开始训练模型以适应特定数据集,那么可以选择pretrained=False

二、模型的保存与加载

模型的保存:

两种保存模式,官方推荐第二种,只保存参数,不保存模型

import torch
import torchvision.models

vgg16 = torchvision.models.vgg16(pretrained=False)

# 保存方式1: 既保存模型结构,也保存参数
torch.save(vgg16, 'vgg16_model1.pth')

# 保存方式2:把参数保存成字典,不保存结构(官方推荐)
torch.save(vgg16.state_dict(), 'vgg16_model2.pth')

print("end")

模型的加载:
 

import torch
import torchvision.models

# 加载方式1 - 保存方式1
model = torch.load('vgg16_model1.pth')

# 加载方式2 - 保存方式2
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load('vgg16_model2.pth'))

三、完整的模型训练套路

以CIFAR10数据集来一个完整的模型训练。

代码示例:

model.py

from torch import nn


# 搭建神经网络
class Lh(nn.Module):
    def __init__(self):
        super(Lh, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x

train.py

import torch
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from model import Lh

# 准备数据集
train_data = torchvision.datasets.CIFAR10('./cifar10', train=True, transform=torchvision.transforms.ToTensor()
                                          , download=True)
test_data = torchvision.datasets.CIFAR10('./cifar10', train=False, transform=torchvision.transforms.ToTensor()
                                         , download=True)
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))

# 利用DataLoader来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# 搭建神经网络 - 10分类
lh = Lh()

# 损失函数
loss_fn = nn.CrossEntropyLoss()

# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(lh.parameters(), lr=learning_rate)

# 设置训练网络的一些参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练轮数
epoch = 10

# 添加tensorboard
writer = SummaryWriter("train_logs")

for i in range(epoch):
    print("-----第{}轮训练开始了-----".format(i + 1))

    # 训练步骤开始
    for data in train_dataloader:
        imgs, tragets = data
        output = lh(imgs)
        loss = loss_fn(output, tragets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1
        if total_train_step % 100 == 0:
            print("训练次数:{},Loss:{}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # 测试步骤开始
    total_test_loss = 0
    total_accuracy = 0

    with torch.no_grad():
        for data in test_dataloader:
            imgs, tragets = data
            output = lh(imgs)
            loss = loss_fn(output, tragets)
            total_test_loss += loss

            accuracy = (output.argmax(1) - - tragets).sum()
            total_accuracy += accuracy

    print("整体测试机上误差:{}".format(total_test_loss))
    print("整体测试机上的正确率:{}".format(total_accuracy / test_data_size))
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy / total_test_step)
    total_test_step += 1

    # torch.save(lh, "lhy_{}.pth".format(i))
    # print("模型已保存")

writer.close()

输出结果:

 在tensorboard打开

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/840491.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

百度智能创做AI平台

家人们好,在数字化时代,人工智能正引领着一场前所未有的创新浪潮。今天,我们将为大家介绍百度智能创做AI平台,这个为创意赋能、助力创作者的强大工具。无论你是创意工作者、内容创作者,还是想要释放内心创造力的个人&a…

人工智能巨头齐聚,研究和掌控人工智能

这四家公司表示,他们成立了Frontier Model Forum,以确保"前沿AI模型的安全和负责任的开发"。 四家全球最先进的人工智能公司成立了一个研究日益强大的人工智能并建立最佳控制实践的组织,随着公众对技术影响的担忧和监管审查的增加…

Python(七十)元组的遍历

❤️ 专栏简介:本专栏记录了我个人从零开始学习Python编程的过程。在这个专栏中,我将分享我在学习Python的过程中的学习笔记、学习路线以及各个知识点。 ☀️ 专栏适用人群 :本专栏适用于希望学习Python编程的初学者和有一定编程基础的人。无…

基于MFCC特征提取和HMM模型的语音合成算法matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 matlab2022A 3.部分核心程序 ............................................................................ %hmm是已经…

桥接模式(C++)

定义 将抽象部分(业务功能)与实现部分(平台实现)分离,使它们都可以独立地变化。 使用场景 由于某些类型的固有的实现逻辑,使得它们具有两个变化的维度,乃至多个纬度的变化。如何应对这种“多维度的变化”?如何利用面向对象技术来使得类型…

一分钟完成centos7安装docker

action: 1、下载安装包2、安装docker 1、背景 使用CentOS / Redhat 7 版本的应该偏多。但是,Docker CE在系统中安装的时候,往往会出现一堆依赖包的报错,解决依赖包需要耗费不短的时间。 经验证,目前已找到兼容能力强的版本&am…

debug思路 - maven构建报错

问题:maven面板中,进行compile、deploy操作时报错。 debug步骤: 1、鼠标右键选择“修改运行配置”。在运行命令中添加参数-X,用于产生执行调试输出。例如:compile -f -X pom.xml。 2、再次进行compile、deploy操作&…

利用线程池多线程并发实现TCP两端通信交互,并将服务端设为守护进程

文章目录 实现目标实现步骤封装日志类封装线程池封装线程封装锁封装线程池 TCP通信的接口和注意事项accept TCP封装任务客户端Client.hppClient.cc 服务端Server.hpp Server.cc实现效果 守护进程服务端守护进程化 实现目标 利用线程池多线程并发实现基于TCP通信的多个客户端与…

《向量数据库指南》——腾讯云向量数据库Tencent Cloud VectorDB应用场景

目录 大模型知识库 推荐系统 问答系统 文本/图像检索 大模型知识库 腾讯云向量数据库可以和大语言模型 LLM 配合使用。企业的私域数据在经过文本分割、向量化后,可以存储在腾讯云向量数据库中,构建起企业专属的外部知识库,从而在后续的检索任务中,为大模型提供提示信息…

C语言每日一题:14《数据结构》复制带随机指针的链表

题目一: 题目链接: 思路一: 找相对位置暴力求解的方法: 1.复制一个新的链表出来遍历老的节点给新的节点赋值,random这个时候不去值。 2.两个链表同时遍历,遍历老链表的时候去寻找相对位置,在遍…

「Qt」常用事件介绍

🔔 在开始本文的学习之前,笔者希望读者已经阅读过《「Qt」事件概念》这篇文章了。本文会在上篇文章的基础上,进一步介绍 Qt 中一些比较常用的事件。 0、引言 当我们想要让控件收到某个事件时做一些操作,通常都需要重写相应的事件处…

Mysql如何实现XA规范

文章目录 一、什么是XA二、XA规范涉及到的角色,以及相关概念1. XA规范涉及到的角色包括:1.1 事务管理器(Transaction Manager):1.2 资源管理器(Resource Manager): 2. 相关概念包括&…

复原 IP 地址——力扣93

文章目录 题目描述回溯题目描述 回溯 class Solution{public:static constexpr int seg_count=4<

解决CF窗口黑边办法

1、桌面鼠标右键点击【英伟达控制面板】 2、点击 【调整桌面尺寸和位置】 3、缩放模式改为 【全屏】即可。 最终效果&#xff1a;

redis的配置和使用、redis的数据结构以及缓存遇见的常见问题

目录 1.缓存 2.redis不仅仅可以做缓存&#xff0c;只不过说他的大部分场景&#xff0c;是做缓存。本地缓存重启后缓存里的东西就没有了&#xff0c;但是redis有。 3.redis有几个特性:查询快&#xff0c;但是是放到内存里的〈断电或者重启&#xff0c;数据就丢了)&#xff0c…

Golang之路---04 并发编程——WaitGroup

WaitGroup 为了保证 main goroutine 在所有的 goroutine 都执行完毕后再退出&#xff0c;前面使用了 time.Sleep 这种简单的方式。 由于写的 demo 都是比较简单的&#xff0c; sleep 个 1 秒&#xff0c;我们主观上认为是够用的。 但在实际开发中&#xff0c;开发人员是无法…

神码ai伪原创工具【php源码】

大家好&#xff0c;小编为大家解答python炫酷烟花表白源代码的问题。很多人还不知道html代码烟花特效python&#xff0c;现在让我们一起来看看吧&#xff01; 火车头采集ai伪原创插件截图&#xff1a; 目录 前言 环境准备 代码编写 效果展示 前言 Python实现浪漫的烟花特效 现在…

Java并发系列之六:CountDownLatch

CountDownLatch作为开发中最常用的组件&#xff0c;今天我们来聊聊它的作用以及内部构造。 首先尝试用一句话对CountDownLatch进行概括: CountDownLatch基于AQS&#xff0c;它实现了闩锁&#xff0c;在开发中可以将其用作任务计数器。 若想要较为系统地去理解这些特性&#xff…

(十三)大数据实战——hadoop集群之YARN高可用实现自动故障转移

前言 本节内容是关于hadoop集群下yarn服务的高可用搭建&#xff0c;以及其发生故障转移的处理&#xff0c;同样需要依赖zookeeper集群的实现&#xff0c;实现该集群搭建时&#xff0c;我们要预先保证zookeeper集群是启动状态。yarn的高可用同样依赖zookeeper的临时节点及监控&…

scanf函数读取数据 清空缓冲区

scanf函数读取数据&清空缓冲区 scanf 从输入缓冲区读取数据数据的接收数据存入缓冲区scanf 中%d读取数据scanf中%c读取数据 清空输入缓冲区例子用getchar()吸收回车练习 scanf 从输入缓冲区读取数据 首先&#xff0c;要清楚的是&#xff0c;scanf在读取数据的时候&#xff…