pytorch进阶学习(四):使用不同分类模型进行数据训练(alexnet、resnet、vgg等)

news2024/11/25 22:30:42

课程资源:5、帮各位写好了十多个分类模型,直接运行即可【小学生都会的Pytorch】_哔哩哔哩_bilibili

 

目录

一、项目介绍

 1. 数据集准备

2. 运行CreateDataset.py

3. 运行TrainModal.py 

4. 如何切换显卡型号

二、代码

1. CreateDataset.py

2.TrainModal.py 

3. 运行结果


一、项目介绍

 1. 数据集准备

数据集在data文件夹下

 

2. 运行CreateDataset.py

运行CreateDataset.py来生成train.txt和test.txt的数据集文件。

 

3. 运行TrainModal.py 

进行模型的训练,从torchvision中的models模块import了alexnet, vgg, resnet的多个网络模型,使用时直接取消注释掉响应的代码即可,比如我现在训练的是vgg11的网络。

    # 不使用预训练参数
   # model = alexnet(pretrained=False, num_classes=5).to(device) # 29.3%

    '''        VGG系列    '''
    model = vgg11(weights=False, num_classes=5).to(device)   #  23.1%
    # model = vgg13(weights=False, num_classes=5).to(device)   # 30.0%
    # model = vgg16(weights=False, num_classes=5).to(device)


    '''        ResNet系列    '''
    # model = resnet18(weights=False, num_classes=5).to(device)    # 43.6%
    # model = resnet34(weights=False, num_classes=5).to(device)
    # model = resnet50(weights= False, num_classes=5).to(device)
    #model = resnet101(weights=False, num_classes=5).to(device)   #  26.2%
    # model = resnet152(weights=False, num_classes=5).to(device)

 同时需要注意的是, vgg11中的weights参数设置为false,我们进入到vgg的定义页发现weights即为是否使用预训练参数,设置为FALSE说明我们不使用预训练参数,因为vgg网络的预训练类别数为1000,而我们自己的数据集没有那么多类,因此不使用预训练。

 

把最后一行中产生的pth的文件名称改成对应网络的名称,如model_vgg11.pth。 

    # 保存训练好的模型
    torch.save(model.state_dict(), "model_vgg11.pth")
    print("Saved PyTorch Model Success!")

4. 如何切换显卡型号

我在运行过程中遇到了“torch.cuda.OutOfMemoryError”的问题,显卡的显存不够,在服务器中使用查看显卡占用情况命令:

nvidia -smi

可以看到我一直用的是默认显卡0,使用情况已经到了100%,但是显卡1使用了67%,还用显存使用空间,所以使用以下代码来把显卡0换成显卡1.

# 设置显卡型号为1
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

 

二、代码

1. CreateDataset.py

'''
生成训练集和测试集,保存在txt文件中
'''
##相当于模型的输入。后面做数据加载器dataload的时候从里面读他的数据
import os
import random#打乱数据用的

# 百分之60用来当训练集
train_ratio = 0.6

# 用来当测试集
test_ratio = 1-train_ratio

rootdata = r"data"  #数据的根目录

train_list, test_list = [],[]#读取里面每一类的类别
data_list = []

#生产train.txt和test.txt
class_flag = -1
for a,b,c in os.walk(rootdata):
    print(a)
    for i in range(len(c)):
        data_list.append(os.path.join(a,c[i]))

    for i in range(0,int(len(c)*train_ratio)):
        train_data = os.path.join(a, c[i])+'\t'+str(class_flag)+'\n'
        train_list.append(train_data)

    for i in range(int(len(c) * train_ratio),len(c)):
        test_data = os.path.join(a, c[i]) + '\t' + str(class_flag)+'\n'
        test_list.append(test_data)

    class_flag += 1

print(train_list)
random.shuffle(train_list)#打乱次序
random.shuffle(test_list)

with open('train.txt','w',encoding='UTF-8') as f:
    for train_img in train_list:
        f.write(str(train_img))

with open('test.txt','w',encoding='UTF-8') as f:
    for test_img in test_list:
        f.write(test_img)

2.TrainModal.py 

'''
    加载pytorch自带的模型,从头训练自己的数据
'''
import time
import torch
from torch import nn
from torch.utils.data import DataLoader
from utils import LoadData

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'


from torchvision.models import alexnet  #最简单的模型
from torchvision.models import vgg11, vgg13, vgg16, vgg19   # VGG系列
from torchvision.models import resnet18, resnet34,resnet50, resnet101, resnet152    # ResNet系列
from torchvision.models import inception_v3     # Inception 系列

# 定义训练函数,需要
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    # 从数据加载器中读取batch(一次读取多少张,即批次数),X(图片数据),y(图片真实标签)。
    for batch, (X, y) in enumerate(dataloader):
        # 将数据存到显卡
        X, y = X.cuda(), y.cuda()

        # 得到预测的结果pred
        pred = model(X)

        # 计算预测的误差
        # print(pred,y)
        loss = loss_fn(pred, y)

        # 反向传播,更新模型参数
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 每训练10次,输出一次当前信息
        if batch % 10 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test(dataloader, model):
    size = len(dataloader.dataset)
    # 将模型转为验证模式
    model.eval()
    # 初始化test_loss 和 correct, 用来统计每次的误差
    test_loss, correct = 0, 0
    # 测试时模型参数不用更新,所以no_gard()
    # 非训练, 推理期用到
    with torch.no_grad():
        # 加载数据加载器,得到里面的X(图片数据)和y(真实标签)
        for X, y in dataloader:
            # 将数据转到GPU
            X, y = X.cuda(), y.cuda()
            # 将图片传入到模型当中就,得到预测的值pred
            pred = model(X)
            # 计算预测值pred和真实值y的差距
            test_loss += loss_fn(pred, y).item()
            # 统计预测正确的个数
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    print(f"correct = {correct}, Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")




if __name__=='__main__':
    batch_size = 8

    # # 给训练集和测试集分别创建一个数据集加载器
    train_data = LoadData("train.txt", True)
    valid_data = LoadData("test.txt", False)


    train_dataloader = DataLoader(dataset=train_data, num_workers=4, pin_memory=True, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(dataset=valid_data, num_workers=4, pin_memory=True, batch_size=batch_size)

    # 如果显卡可用,则用显卡进行训练
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device")


    '''
        随着模型的加深,需要训练的模型参数量增加,相同的训练次数下模型训练准确率起来得更慢
    '''

    # 不使用预训练参数
   # model = alexnet(pretrained=False, num_classes=5).to(device) # 29.3%

    '''        VGG系列    '''
    model = vgg11(weights=False, num_classes=5).to(device)   #  23.1%
    # model = vgg13(weights=False, num_classes=5).to(device)   # 30.0%
    # model = vgg16(weights=False, num_classes=5).to(device)


    '''        ResNet系列    '''
    # model = resnet18(weights=False, num_classes=5).to(device)    # 43.6%
    # model = resnet34(weights=False, num_classes=5).to(device)
    # model = resnet50(weights= False, num_classes=5).to(device)
    #model = resnet101(weights=False, num_classes=5).to(device)   #  26.2%
    # model = resnet152(weights=False, num_classes=5).to(device)




    print(model)
    # 定义损失函数,计算相差多少,交叉熵,
    loss_fn = nn.CrossEntropyLoss()

    # 定义优化器,用来训练时候优化模型参数,随机梯度下降法
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)  # 初始学习率


    # 一共训练1次
    epochs = 1
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        time_start = time.time()
        train(train_dataloader, model, loss_fn, optimizer)
        time_end = time.time()
        print(f"train time: {(time_end-time_start)}")
        test(test_dataloader, model)
    print("Done!")

    # 保存训练好的模型
    torch.save(model.state_dict(), "model_vgg11.pth")
    print("Saved PyTorch Model Success!")

3. 运行结果

vgg11的运行结果:,可以看到最后的准确率是24.8%,因为没有用预训练模型,所以准确率很低。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/429460.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何基于ChatGPT+Avatar搭建24小时无人直播间

0 前言 最近朋友圈以及身边很多朋友都在研究GPT开发,做了各种各样的小工具小Demo,AI工具用起来是真的香!在他们的影响下,我也继续捣鼓GPT Demo,希望更多的开发者加入一起多多交流。 上一篇结合即时通 IM SDK捣鼓了一个…

因为这三个面试题,我与字节offer失之交臂

我一个朋友挑战3个月入职字节,一路过关斩将直到终面,着实把我惊了一把,可惜的是,他倒在了最后三个面试题上。 我很讶异,前面不是打得很好吗?怎么会在最后几题上犯错误呢? 朋友说:别…

电瓶隔离器工作原理与发展简史

电瓶隔离器(Battery Isolators)工作原理与发展简史 电池隔离器(英文:Battery Isolators),又叫双电池隔离器、双电瓶隔离器、双电瓶保护器,还有叫双电池分离器的。 电瓶隔离器其实并没有真正的隔离,负极是始终连在一起的。房车、…

拓展系统命令

文章目录拓展系统命令使用方式拓展系统命令快速运行方法命令 - ZFASTRUN安全运行方法命令 - ZFASTSAFERUN快速运行Query方法命令 - ZFASTQUERY安全运行Query方法 命令 - ZSAFEQUARY防止调试时误将数据提交命令 - ZTRN在Terminal执行SQL语句命令 - ZSQL安全Global命令 - ZSAFEKI…

动态内存管理【上篇】

文章目录⚙️1.为什么存在动态内存分配⚙️2.动态内存函数的介绍📬2.1. malloc函数📬2.2. free函数📬2.3. calloc函数📬2.4. realloc函数⚙️3.常见的动态内存错误🔒3.1.对NULL指针的解引用操作🔒3.2.对动态…

二叉树(OJ)

单值二叉树(力扣) ---------------------------------------------------哆啦A梦的任意门------------------------------------------------------- 我们来看一下题目的具体要求: 既然我们都学了二叉树了,我们就应该学会如何去…

笔记:Java关于轻量级锁与重量级锁之间的问答

问题:如果在轻量级锁状态下出现锁竞争,不一定会直接升级为重量级锁,而是会先尝试自旋获取锁,那么有a b两个线程竞争锁,a成功获取锁了,b就一定失败,那么轻量级锁就一定升级为重量级锁&#xff0c…

基于Bazel + SQLFluff实现SQL lint

背景SQL进行版本化控制后,我们希望为SQL加入lint步骤。这样做的好处是我们可以在真正执行SQL前发现问题。本文中,我们通过Bazel执行SQLFluff[1]以实现SQL的lint。SQLFluff是一款使用Python语言使用的,支持SQL多方言的SQL lint工具。它的特点是…

设计模式-创建型模式之单例模式

6.单例模式6.1. 模式动机对于系统中的某些类来说,只有一个实例很重要,例如,一个系统中可以存在多个打印任务,但是只能有一个正在工作的任务;一个系统只能有一个窗口管理器或文件系统;一个系统只能有一个计时…

360安全卫士退出企业安全云模式

360安全卫士退出企业安全云模式前言360企业安全云关闭企业安全云提醒退出企业安全云模式前言 360安全卫士推出了企业安全云,并会给个人版用户进行推送,虽然可以关闭,但有可能会不小心升级为企业安全云,用户可能并不不习惯&#x…

2023铜鼓半马5月14日开跑,4月18日启动报名!

长寿铜鼓,康养胜地!众翼电气2023铜鼓半程马拉松暨英雄马系列赛(铜鼓站)新闻发布会今日召开,铜鼓县委常委、宣传部部长熊涛,铜鼓县教育体育局党委书记、局长孙桃基,铜鼓县文广新旅局党组书记、局…

SpringBoot API 接口防刷

SpringBoot API 接口防刷接口防刷接口防刷原理代码实现RequestLimit 注解RequestLimitIntercept 拦截器WebMvcConfig配置类Controller控制层验证接口防刷 接口防刷: 顾名思义,想让某个接口某个人在某段时间内只能请求N次。 在项目中比较常见的问题也有,…

【Python】Python程序中使用request库连接外国网站的方法

确认你的socks端口: 然后程序可以这么写: import requests import socks import socket# 创建 SOCKS5 代理连接 socks.set_default_proxy(socks.SOCKS5, "127.0.0.1", 10808) socket.socket socks.socksocket# 发送请求 response request…

Java高级特性 - 多线程基础(2)常用函数【第1关:线程的状态与调度 第2关:常用函数(一)第3关:常用函数(二)】

目录 第1关:线程的状态与调度 第2关:常用函数(一) 第3关:常用函数(二) 第1关:线程的状态与调度 相关知识 为了完成本关你需要掌握: 1.线程的状态与调度&#xff1b…

Linux内核中常用的数据结构和算法

文章目录链表红黑树无锁环形缓冲区Linux内核代码中广泛使用了数据结构和算法,其中最常用的两个是链表和红黑树。 链表 Linux内核代码大量使用了链表这种数据结构。链表是在解决数组不能动态扩展这个缺陷而产生的一种数据结构。链表所包含的元素可以动态创建并插入和…

APP自动化测试(14)-利用xpath定位元素

一、元素定位的困难 定位元素时有时无法准确定位到我们想要的元素,存在如下几种情况 1、通过一个条件无法准确定位到元素,需要进行条件组合 2、某元素无法唯一定位到,但是同级的其他元素可以唯一定位 3、某元素的属性无论如何组合都无法唯…

训练机器学习模型,可使用 Sklearn 提供的 16 个数据集 【下篇】

数据是机器学习算法的动力,scikit-learn或sklearn提供了高质量的数据集,被研究人员、从业人员和爱好者广泛使用。Scikit-learn(sklearn)是一个建立在SciPy之上的机器学习的Python模块。它的独特之处在于其拥有大量的算法、十分易用…

AOP使用场景记录总结(缓慢补充更新中)

测试项目结构: 目前是测试两个日志记录和 代码的性能测试 后面如果有其他的应用场景了在添加.其实一中就包括了二,但是没事,多练一遍 1. 日志记录 比如说对service层中的所有增加,删除,修改方法添加日志, 记录内容包括操作的时间 操作的方法, 方法的参数, 方法所在的类, 方法…

CSS :autofill 如何覆盖浏览器自动填充表单的样式

CSS :autofill 如何覆盖浏览器自动填充表单的样式 :autofill 伪类匹配浏览器自动填充值的 input 元素. 如果用户继续编辑这个元素内容就会停止匹配. #name:autofill {background-color: red !important;border: 6px solid red; } #name:-webkit-autofill {background-color: …

OpenAI-ChatGPT最新官方接口《审核机制》全网最详细中英文实用指南和教程,助你零基础快速轻松掌握全新技术(七)(附源码)

Moderation 审核机制前言Introduction 导言Quickstart 快速开始其它资料下载ChatGPT 作为一个大型人工智能语言模型,在提供用户便捷交流的同时也承担着内容审核的责任。为了保护用户和社会免受不良信息的影响,ChatGPT 特别注重关于内容的审核。当用户发送…