Pytorch自定义数据集模型训练流程

news2024/11/25 20:28:55

文章目录

  • Pytorch模型自定义数据集训练流程
    • 1、任务描述
    • 2、导入各种需要用到的包
    • 3、分割数据集
    • 4、将数据转成pytorch标准的DataLoader输入格式
    • 5、导入预训练模型,并修改分类层
    • 6、开始模型训练
    • 7、利用训好的模型做预测

Pytorch模型自定义数据集训练流程

我们以kaggle竞赛中的猫狗大战数据集为例搭建Pytorch自定义数据集模型训练的完整流程。

1、任务描述

Cats vs. Dogs(猫狗大战)数据集是Kaggle大数据竞赛某一年的一道赛题,利用给定的数据集,用算法实现猫和狗的识别。 其中包含了训练集和测试集,训练集中猫和狗的图片数量都是12500张且按顺序排序,测试集中猫和狗混合乱序图片一共12500张。
下载地址:https://www.kaggle.com/c/dogs-vs-cats/data

在这里插入图片描述

卷积神经网络(CNN)是一类包含卷积计算且具有深度结构的前馈神经网络,是深度学习的代表算法之一。卷积神经网络具有表征学习能力,能够按其阶层结构对输入信息进行平移不变分类,因此也被称为“平移不变人工神经网络”。
默认对图像分类各种算法已经熟悉,卷积、池化、批量归一化、全连接等各种结构具体细节这里不讨论,有不懂的需自行学习。

2、导入各种需要用到的包

import torch
import torchvision
from torchvision import datasets, transforms
import torch.utils.data
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset,DataLoader,Dataset
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from torch import nn
import numpy as np
import os
import shutil
from PIL import Image
import warnings
warnings.filterwarnings("ignore")

3、分割数据集

下载猫狗大战数据集,并解压。
解压完成后,通过以下代码实现数据集预处理(剔除不能正常打开的图片,打乱数据集等);然后对数据集进行分割,其中90%的数据集作为train训练,另外10%的数据集作为test测试。

# 分割数据集,将全部数据分成0.9的Train和0.1的Test
source_path = r"./kagglecatsanddogs_5340/PetImages/"
# 如果不存在文件夹要新建一个
if not os.path.exists(os.path.join(source_path, "train")):
    os.mkdir(os.path.join(source_path, "train"))
train_dir = os.path.join(source_path, "train")

if not os.path.exists(os.path.join(source_path, "test")):
    os.mkdir(os.path.join(source_path, "test"))
test_dir = os.path.join(source_path,"test")

## 将Cat和Dog文件夹全部移到train目录下,然后再从train目录下移动10%到test目录下
for category_dir in os.listdir(source_path):
    if category_dir not in ["train", "test"]:
        shutil.move(os.path.join(source_path,category_dir), os.path.join(source_path,"train"))
            
## 开始移动,移动前先剔除不能正常打开的图片
for dir in os.listdir(train_dir):
    category_dir_path = os.path.join(train_dir, dir)
    image_file_list = os.listdir(category_dir_path)   # 取出全部图片文件
    for file in image_file_list:
        try:
            Image.open(os.path.join(category_dir_path, file))
        except:
            os.remove(os.path.join(category_dir_path, file))
            image_file_list.remove(file)
    np.random.shuffle(image_file_list)
    test_num = int(0.1*len(image_file_list))
 
    #移动10%文件到对应目录
    if not os.path.exists(os.path.join(test_dir,dir)):
        os.mkdir(os.path.join(test_dir,dir))
    if len(os.listdir(os.path.join(test_dir,dir))) < test_num:  # 只有未移动过才需要移动,否则每运行一次都会移动一下
        for i in range(test_num):
            shutil.move(os.path.join(category_dir_path,image_file_list[i]), os.path.join(test_dir,dir,image_file_list[i]))

4、将数据转成pytorch标准的DataLoader输入格式

1、先对数据集进行预处理,包括resize成224*224的尺寸,因为vgg_net模型需要的输入尺寸为[N, 224, 224, 3];随机翻转,随机旋转等,另外对数据集做Normalize标准化,其中的mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.2]是从ImageNet数据集上的百万张图片中随机抽样计算得到的,以上这些内容主要是数据增强,增强模型的泛化性,有更好的预测效果。
2、然后将预处理好的数据转成pytorch标准的DataLoader输入格式,。

# 数据预处理
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),# 对图像进行随机的crop以后再resize成固定大小
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456,
0.406],std=[0.229, 0.224, 0.2]),  # ImageNet全部图片的平均值和标准差
    transforms.RandomRotation(20), # 随机旋转角度
    transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转
])
 
# 读取数据
root = source_path
train_dataset = datasets.ImageFolder(root + '/train', transform)
test_dataset = datasets.ImageFolder(root + '/test', transform)
 
# 导入数据
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)

5、导入预训练模型,并修改分类层

1、定义device,如果有GPU模型训练会自动用GPU训练,否则会使用CPU;使用GPU训练,只需在模型、数据、损失函数上使用cuda()就行。
2、这边默认对分类图像算法都熟悉,可以自己构建vgg16的完整网络,在猫狗数据集上重新训练。也可以下载预训练模型,由于原网络的分类输出是1000类别的,但是我们的图片只有两类,所以需要修改分类层,让模型能够适配我们的训练数据集。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
vgg16 = torchvision.models.vgg16(pretrained=True).to(device)
print(vgg16)

inputs = torch.rand(1, 3, 224, 224)  # 拿一个随机tensor测试一下网络的输出是否满足预期
output = vgg16(inputs.to(device))
print("原始VGG网络的输出:",output.size())

# 构建新的全连接层
vgg16.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 100),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(100, 2)).to(device)
inputs = torch.rand(1, 3, 224, 224)
output = vgg16(inputs.to(device))
print("新构建的VGG网络的输出:",output.size())

6、开始模型训练

开始模型训练,我们这里只训练全连接分类层,将特征层的梯度requires_grad设置为False,特征层的参数将不参与训练。
训练过程中保存效果最好的网络模型,以防掉线,可以从断点开始继续训练,同时也可以用来做预测。
训练完成后,保存训练好的网络和参数,后面可以加载模型做预测。

writer = SummaryWriter("./logs/model")
loss_func = nn.CrossEntropyLoss().to(device)
learning_rate = 0.0001

#如果我们想只训练模型的全连接层
for param in vgg16.features.parameters():
    param.requires_grad = False
optimizer = torch.optim.Adam(vgg16.parameters(),lr=learning_rate)

##训练开始
total_train_step = 0
total_test_step = 0
min_acc = 100.0
for epoch in range(10):
    print("-----------train epoch {} start---------------".format(epoch))
    vgg16.train()
    for data in train_loader:
        optimizer.zero_grad()
        img, label = data
        output = vgg16(img.to(device))
        loss = loss_func(output, label.to(device))
        loss.backward()
        optimizer.step()
        total_train_step += 1
        
        if total_train_step % 10 == 0:
            print("steps: {}, train_loss: {}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)


    ## 测试开始,看训练效果是否满足预期
    total_test_loss = 0
    total_acc = 0.0
    vgg16.eval()
    with torch.no_grad():
        for data in test_loader:
            optimizer.zero_grad()
            img, label = data
            output = vgg16(img.to(device))
            loss = loss_func(output, label.to(device))
            total_test_loss += loss
            accuary = torch.sum(output.argmax(1) == label.to(device))
            total_acc += accuary
    total_test_step += 1
    val_acc = total_acc.item() / len(test_dataset)
    
    ## 保存Acc最小的模型
    if val_acc < min_acc:
        min_acc = val_acc
        torch.save(vgg16.state_dict(), "./models/2classes_vgg16_weight_{}_{}.pth".format(epoch, round(val_acc,4)))
        torch.save(vgg16, "./models/2classes_vgg16_{}_{}.pth".format(epoch, round(val_acc,4)))

    print("测试loss: {}".format(total_test_loss.item()))
    print("测试Acc: {}".format(val_acc))
    writer.add_scalar("test_loss", total_test_loss.item(), total_test_step)
    writer.add_scalar("test_Acc", val_acc, total_test_step)

torch.save(vgg16.state_dict(), "./models/2classes_vgg16_latest_{}.pth".format(val_acc))

7、利用训好的模型做预测

拿出一张图片做预测,首先导入预训练模型,同样改掉分类层,然后导入预训练权重,预测图片类别,输出标签值和预测类别。

import matplotlib.pyplot as plt
img_path = r"./kagglecatsanddogs_5340/PetImages/test/Cat/1381.jpg"   # 拿出要预测的图片
image = Image.open(img_path).convert("RGB")
image.show()
    
vgg16_pred = torchvision.models.vgg16(pretrained=True)
vgg16_pred.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 100),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(100, 2))

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224,224), interpolation=2),
    torchvision.transforms.ToTensor()
])
vgg16_pred.load_state_dict(torch.load("./models/2classes_vgg16_weight_15_0.9467513434294089.pth", map_location=torch.device('cpu')))
print(vgg16_pred)

image = transform(image)
print(image.size())
image = torch.reshape(image, [1,3,224,224])
vgg16_pred.eval()
with torch.no_grad():
    output = vgg16_pred(image)
# print("预测值为:",output)
print("预测标签为:",output.argmax(1).item())
print("预测动物为:",train_dataset.classes[output.argmax(1)])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/167531.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

响应式流的核心机制——背压机制

一、响应式流是什么&#xff1f; Reactive Streams 是 2013 年底由 Netflix、Lightbend 和 Pivotal&#xff08;Spring 背后的公司&#xff09;的工程师发起的一项计划&#xff0c;响应式流旨在为无阻塞异步流处理提供一个标准。它旨在解决处理元素流的问题——如何将元素流从…

【BP靶场portswigger-客户端14】点击劫持-5个实验(全)

前言&#xff1a; 介绍&#xff1a; 博主&#xff1a;网络安全领域狂热爱好者&#xff08;承诺在CSDN永久无偿分享文章&#xff09;。 殊荣&#xff1a;CSDN网络安全领域优质创作者&#xff0c;2022年双十一业务安全保卫战-某厂第一名&#xff0c;某厂特邀数字业务安全研究员&…

Fastdfs分布式文件系统原理浅析

文章目录1、fastdfs文件系统原理简述2、storage server状态2.1 组内新增加一台storage server A时&#xff0c;由系统自动完成已有数据同步&#xff0c;处理逻辑如下&#xff1a;第一步&#xff1a;第二步&#xff1a;第三步&#xff1a;第四步&#xff1a;3、同步时间管理4、B…

[有人@你]请查收你的年终总结报告

嗨&#xff0c;兄dei&#xff0c;我是建模助手。 新年伊始&#xff0c;最近大家想必已经被各大平台的2022年度报告刷屏了。 听歌软件伴你度过的失眠夜&#xff0c;外卖软件拯救你的饥饿时刻&#xff0c;还有某俩宝账单告诉你&#xff0c;其实你是有钱的&#xff0c;只是你看不到…

基于有向图的邻接矩阵计算其割点、割边、压缩图,并用networkx可视化绘制

基于有向图的邻接矩阵计算其割点、割边、压缩图&#xff0c;并用networkx可视化绘制为什么基于邻接矩阵计算图的割点、割边、压缩图实现python代码代码运行效果结论&#xff1a;为什么基于邻接矩阵计算图的割点、割边、压缩图 由于矩阵计算过程&#xff0c;被广泛优化&#xf…

Linux关于 gdb 调试器的使用

坚持看完&#xff0c;结尾有思维导图总结 这里写目录标题debug 和 release 版本gdb 常见命令断点逐行调试和观察变量总结debug 和 release 版本 首先要说的是 &#xff0c;在 Linux 中 gcc 直接编译是不能进行调试的 而是要在加上 -g 选项才能得到可调试的文件 以下程序用一个…

算法第十二期——BFS-双向广搜

双向广搜 应用场景&#xff1a;有确定的起点s和终点t&#xff1b;把从起点到终点的单向搜索&#xff0c;变换为分别从起点出发和从终点出发的“相遇”问题。操作&#xff1a;从起点s(正向搜索&#xff09;和终点t(逆向搜索&#xff09;同时开始搜索&#xff0c;当两个搜索产生…

Spring入门-Spring事务管理

文章目录1&#xff0c;Spring事务管理1.1 Spring事务简介1.1.1 相关概念介绍1.1.2 转账案例-需求分析1.1.3 转账案例-环境搭建步骤1:准备数据库表步骤2:创建项目导入jar包步骤3:根据表创建模型类步骤4:创建Dao接口步骤5:创建Service接口和实现类步骤6:添加jdbc.properties文件步…

数据治理与档案信息资源体系建设

如果要评选大数据或者数字化转型领域中哪个词最让人费解、最讲不清楚&#xff0c;“数据治理&#xff08;Data Governance&#xff09;”绝对是候选之一。说实话&#xff0c;笔者到现在也没有完全整明白&#xff0c;因为数据治理包含的范围太广了&#xff0c;可以说是包罗万象&…

高潜人才的自我要求

前言&#xff0c;上次写了个《潜力出众的你有这样的特质吗&#xff1f;》&#xff0c;地址如下&#xff1a;点我查看&#xff0c;这次在写个高潜人才的自我要求。本次以6个纬度来进行分析&#xff1b;3是基本要求&#xff0c;4是追求卓越&#xff0c;看你目前做到了哪个级别&am…

跨平台API对接(Python)的使用

Jenkins 是一个开源的、提供友好操作界面的持续集成(CI)工具&#xff0c;起源于 Hudson&#xff08;Hudson 是商用的&#xff09;&#xff0c;主要用于持续、自动的构建/测试软件项目、监控外部任务的运行。后端可以利用 Jenkins 对任务进行调度运行&#xff1a;后端可利用 HTT…

【进阶】Spring更简单的读取和存储对象

努力经营当下&#xff0c;直至未来明朗&#xff01; 文章目录一、存储Bean对象一&#xff09;前置工作&#xff1a;配置扫描路径&#xff08;重要&#xff09;二&#xff09;添加注解存储Bean对象3. 五大类注解&#xff1a;4. 方法注解&#xff1a;6. 相关问题7. 补充【结论、查…

ROS2机器人编程简述humble-第二章-DEVELOPING THE FIRST NODE .2

0.1ROS2机器人编程简述新书推荐-A Concise Introduction to Robot Programming with ROS21.1ROS2机器人编程简述humble-第一章-Introduction2.1ROS2机器人编程简述humble-第二章-First Steps with ROS2 .12.2主要内容是全手工创建一个最简单的自定义节点&#xff0c;其实没啥具…

IB学生必看的时间表(二)

上期谈到在IB预科课程的第一个学年下学期&#xff0c;便要开始作报读大学的准备&#xff0c;到底为什么&#xff1f; 暑假不容松懈 现在来到放暑假了。虽说不用上课&#xff0c;学生没有了学习压力&#xff0c;但就以下三方面来看&#xff0c;学生还是要继续投放心力。 首先&am…

Unity 之 Addressable可寻址系统 -- 代码加载介绍 -- 进阶(一)

Unity 之 可寻址系统 -- 代码加载介绍 -- 进阶&#xff08;一&#xff09;一&#xff0c;可寻址系统代码加载1.1 回调形式1.2 异步等待1.3 面板赋值1.4 同步加载二&#xff0c;可寻址系统分标签加载2.1 场景搭建2.2 代码示例2.3 效果展示三&#xff0c;代码加载可寻址的解释概述…

Cadence OrCAD: 跨页符和电源符号命名优先级的一个小问题

Cadence OrCAD: 跨页符和电源符号命名优先级的一个小问题 遇到的问题 最近项目中&#xff0c;有个电源需要做负载端的反馈&#xff0c;类似下图的signal1和signal1N&#xff0c;反馈使用类似伪差分线&#xff0c;把电压信号和负载端的GND都连到DC-DC控制器。DC-DC对应的反馈引…

字节跳动青训营--前端day1

文章目录前言一、 前端1 前端的技术栈2. 前端的边界3. 前端的关注点二、 HTML1. HTML常用标签及语义化2. HTML 语法3. 谁在使用我们写的HTML前言 仅以此文章记录学习历程 一、 前端 解决GUI人机交互问题 1 前端的技术栈 2. 前端的边界 nodejs–服务器端应用 electron… --客…

【数据结构】6.1 图的基本概念和术语

文章目录前言6.1 图的定义和术语前言 图是一种比线性表和树更为复杂的数据结构。 在线性结构中&#xff0c;结点之间的关系属于一个对一个&#xff1b;数据元素之间有着线性关系&#xff0c;每个数据元素只有一个直接前趋和一个直接后继&#xff0c; 在树形结构中&#xff0c;…

算法设计与分析课程

算法的由来 算法的定义 算法的定义&#xff1a;给定计算问题&#xff0c;算法是一系列良定义的计算步骤&#xff0c;逐一执行计算步骤可得到预期的输出。 良定义&#xff1a;定义明确无歧义 计算步骤&#xff1a;计算机可以实现的指令 有了良定义的计算步骤&#xff0c;计算机就…

Java基础篇01-运算符的使用

01| Java中的数据类型 ) 1. 数值型&#xff1a; 序号类型空间占用说明最小值最大值默认值优缺点对比举例1byte8位有符号整数-128(-2^7)127 (2^7-1)0byte 类型用在大型数组中节约空间&#xff0c;主要代替整数&#xff0c;因为 byte 变量占用的空间只有 int 类型的四分之一by…