Pytorch模型自定义数据集训练流程

news2025/1/1 10:46:01

文章目录

  • Pytorch模型自定义数据集训练流程
    • 1、任务描述
    • 2、导入各种需要用到的包
    • 3、分割数据集
    • 4、将数据转成pytorch标准的DataLoader输入格式
    • 5、导入预训练模型,并修改分类层
    • 6、开始模型训练
    • 7、利用训好的模型做预测

Pytorch模型自定义数据集训练流程

我们以kaggle竞赛中的猫狗大战数据集为例搭建Pytorch自定义数据集模型训练的完整流程。

1、任务描述

Cats vs. Dogs(猫狗大战)数据集是Kaggle大数据竞赛某一年的一道赛题,利用给定的数据集,用算法实现猫和狗的识别。 其中包含了训练集和测试集,训练集中猫和狗的图片数量都是12500张且按顺序排序,测试集中猫和狗混合乱序图片一共12500张。

在这里插入图片描述

卷积神经网络(CNN)是一类包含卷积计算且具有深度结构的前馈神经网络,是深度学习的代表算法之一。卷积神经网络具有表征学习能力,能够按其阶层结构对输入信息进行平移不变分类,因此也被称为“平移不变人工神经网络”。不使用深度学习框架,用numpy基础代码来构建自己的深度学习网络。

2、导入各种需要用到的包

import torch
import torchvision
from torchvision import datasets, transforms
import torch.utils.data
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset,DataLoader,Dataset
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from torch import nn
import numpy as np
import os
import shutil
from PIL import Image
import warnings
warnings.filterwarnings("ignore")

3、分割数据集

下载猫狗大战数据集,并解压。
下载地址:https://www.kaggle.com/c/dogs-vs-cats/data。
解压完成后,通过以下代码实现数据集预处理(剔除不能正常打开的图片,打乱数据集);然后对数据集进行分割,其中0.9的数据集作为train训练,0.1的数据集作为test测试。

# 分割数据集,将全部数据分成0.9的Train和0.1的Test
source_path = r"./kagglecatsanddogs_5340/PetImages/"
# 如果不存在文件夹要新建一个
if not os.path.exists(os.path.join(source_path, "train")):
    os.mkdir(os.path.join(source_path, "train"))
train_dir = os.path.join(source_path, "train")

if not os.path.exists(os.path.join(source_path, "test")):
    os.mkdir(os.path.join(source_path, "test"))
test_dir = os.path.join(source_path,"test")

## 将Cat和Dog文件夹全部移到train目录下,然后再从train目录下移动10%到test目录下
for category_dir in os.listdir(source_path):
    if category_dir not in ["train", "test"]:
        shutil.move(os.path.join(source_path,category_dir), os.path.join(source_path,"train"))
            
## 开始移动,移动前先剔除不能正常打开的图片
for dir in os.listdir(train_dir):
    category_dir_path = os.path.join(train_dir, dir)
    image_file_list = os.listdir(category_dir_path)   # 取出全部图片文件
    for file in image_file_list:
        try:
            Image.open(os.path.join(category_dir_path, file))
        except:
            os.remove(os.path.join(category_dir_path, file))
            image_file_list.remove(file)
    np.random.shuffle(image_file_list)
    test_num = int(0.1*len(image_file_list))
 
    #移动10%文件到对应目录
    if not os.path.exists(os.path.join(test_dir,dir)):
        os.mkdir(os.path.join(test_dir,dir))
    if len(os.listdir(os.path.join(test_dir,dir))) < test_num:  # 只有未移动过才需要移动,否则每运行一次都会移动一下
        for i in range(test_num):
            shutil.move(os.path.join(category_dir_path,image_file_list[i]), os.path.join(test_dir,dir,image_file_list[i]))

4、将数据转成pytorch标准的DataLoader输入格式

1、先对数据集进行预处理,包括resize成224*224的尺寸,因为vgg_net模型需要的输入尺寸为[N, 224, 224, 3];随机翻转,随机旋转等,另外对数据集做Normalize标准化,其中的mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.2]是从ImageNet数据集上的百万张图片中随机抽样计算得到的,以上这些内容主要是数据增强,增强模型的泛化性,有更好的预测效果。
2、然后将预处理好的数据转成pytorch标准的DataLoader输入格式,。

# 数据预处理
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),# 对图像进行随机的crop以后再resize成固定大小
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456,
0.406],std=[0.229, 0.224, 0.2]),  # ImageNet全部图片的平均值和标准差
    transforms.RandomRotation(20), # 随机旋转角度
    transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转
])
 
# 读取数据
root = source_path
train_dataset = datasets.ImageFolder(root + '/train', transform)
test_dataset = datasets.ImageFolder(root + '/test', transform)
 
# 导入数据
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)

5、导入预训练模型,并修改分类层

1、定义device,如果有GPU模型训练会自动用GPU训练,否则会使用CPU;使用GPU训练,只需在模型、数据、损失函数上使用cuda()就行。
2、这边默认对分类图像算法都熟悉,可以自己构建vgg16的完整网络,在猫狗数据集上重新训练。也可以下载预训练模型,由于原网络的分类输出是1000类别的,但是我们的图片只有两类,所以需要修改分类层,让模型能够适配我们的训练数据集。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
vgg16 = torchvision.models.vgg16(pretrained=True).to(device)
print(vgg16)

inputs = torch.rand(1, 3, 224, 224)  # 拿一个随机tensor测试一下网络的输出是否满足预期
output = vgg16(inputs.to(device))
print("原始VGG网络的输出:",output.size())

# 构建新的全连接层
vgg16.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 100),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(100, 2)).to(device)
inputs = torch.rand(1, 3, 224, 224)
output = vgg16(inputs.to(device))
print("新构建的VGG网络的输出:",output.size())

6、开始模型训练

开始模型训练,我们这里只训练全连接分类层,将特征层的梯度requires_grad设置为False,特征层的参数将不参与训练。
训练过程中保存效果最好的网络模型,以防掉线,可以从断点开始继续训练,同时也可以用来做预测。
训练完成后,保存训练好的网络和参数,后面可以加载模型做预测。

writer = SummaryWriter("./logs/model")
loss_func = nn.CrossEntropyLoss().to(device)
learning_rate = 0.0001

#如果我们想只训练模型的全连接层
for param in vgg16.features.parameters():
    param.requires_grad = False
optimizer = torch.optim.Adam(vgg16.parameters(),lr=learning_rate)

##训练开始
total_train_step = 0
total_test_step = 0
min_acc = 100.0
for epoch in range(10):
    print("-----------train epoch {} start---------------".format(epoch))
    vgg16.train()
    for data in train_loader:
        optimizer.zero_grad()
        img, label = data
        output = vgg16(img.to(device))
        loss = loss_func(output, label.to(device))
        loss.backward()
        optimizer.step()
        total_train_step += 1
        
        if total_train_step % 10 == 0:
            print("steps: {}, train_loss: {}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)


    ## 测试开始,看训练效果是否满足预期
    total_test_loss = 0
    total_acc = 0.0
    vgg16.eval()
    with torch.no_grad():
        for data in test_loader:
            optimizer.zero_grad()
            img, label = data
            output = vgg16(img.to(device))
            loss = loss_func(output, label.to(device))
            total_test_loss += loss
            accuary = torch.sum(output.argmax(1) == label.to(device))
            total_acc += accuary
    total_test_step += 1
    val_acc = total_acc.item() / len(test_dataset)
    
    ## 保存Acc最小的模型
    if val_acc < min_acc:
        min_acc = val_acc
        torch.save(vgg16.state_dict(), "./models/2classes_vgg16_weight_{}_{}.pth".format(epoch, round(val_acc,4)))
        torch.save(vgg16, "./models/2classes_vgg16_{}_{}.pth".format(epoch, round(val_acc,4)))

    print("测试loss: {}".format(total_test_loss.item()))
    print("测试Acc: {}".format(val_acc))
    writer.add_scalar("test_loss", total_test_loss.item(), total_test_step)
    writer.add_scalar("test_Acc", val_acc, total_test_step)

torch.save(vgg16.state_dict(), "./models/2classes_vgg16_latest_{}.pth".format(val_acc))

7、利用训好的模型做预测

拿出一张图片做预测,首先导入预训练模型,同样改掉分类层,然后导入预训练权重,预测图片类别,输出标签值和预测类别。

import matplotlib.pyplot as plt
img_path = r"./kagglecatsanddogs_5340/PetImages/test/Cat/1381.jpg"   # 拿出要预测的图片
image = Image.open(img_path).convert("RGB")
image.show()
    
vgg16_pred = torchvision.models.vgg16(pretrained=True)
vgg16_pred.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 100),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(100, 2))

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224,224), interpolation=2),
    torchvision.transforms.ToTensor()
])
vgg16_pred.load_state_dict(torch.load("./models/2classes_vgg16_weight_15_0.9467513434294089.pth", map_location=torch.device('cpu')))
print(vgg16_pred)

image = transform(image)
print(image.size())
image = torch.reshape(image, [1,3,224,224])
vgg16_pred.eval()
with torch.no_grad():
    output = vgg16_pred(image)
# print("预测值为:",output)
print("预测标签为:",output.argmax(1).item())
print("预测动物为:",train_dataset.classes[output.argmax(1)])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/167481.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

交互与前端20 APIFunc.DataBase监控

说明 APIFunc.DataBase的第一版有一个监控一直在做agg,造成数据库的无谓消耗,所以一定得修补。在修补的同时,做了一些主要的修改: 1 【自增ID】给Mongo的In和Out增加了数据的自动编号和随机数生成。2 【使用缓存】通过Redis缓存,极大的的减轻了Mongo(主库)的负担这样,使得…

Kruskal重构树学习笔记(C++)

Kruskal重构树学习笔记 提示&#xff1a; 学习Kruskal重构树之前建议先了解一下Kruskal算法&#xff0c;虽然不了解这个影响不会很大 但一定要了解一下并查集的算法 接下来如果想要应用Kruskal重构树&#xff0c;一定要了解一下LCA算法 什么是Kruskal重构树 这里先简单说…

exec函数族详解

文章目录exec介绍exec族execl函数execlp函数execv函数exec介绍 通过命令查看帮助&#xff1a;man 3 exec exec 函数族的作用是根据指定的文件名找到可执行文件&#xff0c;并用它来取代调用进程的内容&#xff0c;换句话说&#xff0c;就是在调用进程内部执行一个可执行文件。…

基于多线程版本的定时器

定时器 1)咱们前面学习过的阻塞队列&#xff0c;相比于普通的队列线程安全&#xff0c;相比于普通的队列起到一个更好的阻塞效果 2)虽然使用阻塞队列&#xff0c;可以达到销峰填谷这样的一个效果&#xff0c;但是峰值中有大量的数据涌入到队列中&#xff0c;如果后续的服务器消…

教程:Flutter 和 Rust混合编程,使用flutter_rust_bridge自动生成ffi代码

实践环境&#xff1a;Arch Linuxflutter_rust_bridge官方文档Flutter环境配置教程 | Rust环境配置教程记录使用flutter_rust_bridge遇到的一些坑。假设已经我们配置了Fluuter与Rust环境现在直接使用flutter_rust_bridge模板创建自己的项目运行&#xff1a;git clone https://gi…

W13Scan 扫描器挖掘漏洞实践

一、背景 这段时间总想捣鼓扫描器&#xff0c;发现自己的一些想法很多前辈已经做了东西&#xff0c;让我有点小沮丧同时也有点小兴奋&#xff0c;说明思路是对的&#xff0c;我准备站在巨人的肩膀去二次开发&#xff0c;加入一些自己的想法&#xff0c;从freebuf中看到W13Scan…

进程调度模块

目录 1.进程介绍 2.进程调度 2.1.进程状态 2.2.进程调度函数 ---schedule 2.3.进程切换函数 ---switch_to&#xff08;&#xff09; 1.进程介绍 在进程模块里面&#xff0c;我们知道了进程就是一个task_struct的结构体&#xff0c;里面含有进程的各种信息。进程存放在进程…

AppScan被动手动探索扫描

系列文章 AppScan介绍和安装 AppScan 扫描web应用程序 第三节-AppScan被动手动探索扫描 被动式扫描&#xff1a;浏览器代理到AppScan&#xff0c;然后进行手工操作&#xff0c;探索产生出的流量给AppScan进行扫描。 他的优点是&#xff1a;扫描足够精准&#xff0c;覆盖率更…

注册中心和负载均衡(黑马SpringCloud笔记)

注册中心和负载均衡 目录注册中心和负载均衡一、服务远程调用1. RestTemplate2. 服务调用关系3. 远程调用的问题二、注册中心1. Eureka注册中心1.1 搭建Eureka注册中心1.2 服务注册1.3 服务拉取1.4 小结2. nacos注册中心2.1Nacos搭建2.2 服务注册2.3 服务拉取2.4 服务分级存储模…

虹科新闻 | 虹科与丹麦Eupry正式建立合作伙伴关系

近期&#xff0c;虹科与丹麦Eupry正式建立合作伙伴关系。未来&#xff0c;虹科与Eupry将共同关注最具创新性和稳定性的解决方案&#xff0c;为客户提供温度记录仪、温湿度记录仪、Mapping温度分布验证服务、以及基于云的温湿度自动监测系统。 虹科非常高兴欢迎并宣布我们的新合…

【Linux】基础:进程信号

【Linux】基础&#xff1a;进程信号 摘要&#xff1a;本文将会从生活实际出发&#xff0c;由此掌握进程信号的学习过程&#xff0c;分别为信号的产生、信号的传输、信号的保存和信号的处理&#xff0c;最后再补充学习信号后方便理解的其他概念。 文章目录【Linux】基础&#xf…

echarts柱状图值为0时不显示以及柱状图百分比展示

echarts柱状图值为0时不显示以及柱状图百分比展示 1.效果展示 2.代码 <template><div id"container"><div id"main"></div></div> </template> <script>import * as echarts from echarts import * as lodash…

(JVM)浅堆深堆与内存泄露

​浅堆深堆与内存泄露 1. 浅堆&#xff08;Shallow Heap&#xff09; 浅堆是指一个对象所消耗的内存。在 32 位系统中&#xff0c;一个对象引用会占据 4 个字节&#xff0c;一个 int 类型会占据 4 个字节&#xff0c;long 型变量会占据 8 个字节&#xff0c;每个对象头需要占用…

01.【Vue】Vue2基础操作

一、Vue Vue (读音 /vjuː/&#xff0c;类似于 view) 是一套用于构建用户界面的渐进式框架。与其它大型框架不同的是&#xff0c;Vue 被设计为可以自底向上逐层应用。Vue 的核心库只关注视图层&#xff0c;不仅易于上手&#xff0c;还便于与第三方库或既有项目整合。另一方面&…

十五天学会Autodesk Inventor,看完这一系列就够了(七),工程图纸

众所周知&#xff0c;Autocad是一款用于二维绘图、详细绘制、设计文档和基本三维设计&#xff0c;现已经成为国际上广为流行的绘图工具。Autodesk Inventor软件也是美国AutoDesk公司推出的三维可视化实体模拟软件。因为很多人都熟悉Autocad&#xff0c;所以再学习Inventor&…

自动化测试 | 这些常用测试平台,你们公司在用的是哪些呢?

本文节选自霍格沃兹测试学院内部教材 测试管理平台是贯穿测试整个生命周期的工具集合&#xff0c;它主要解决的是测试过程中团队协作的问题。在整个测试过程中&#xff0c;需要对测试用例、Bug、代码、持续集成等等进行管理。下面分别从这四个方面介绍现在比较流行的管理平台。…

Spring入门-SpringAOP详解

文章目录SpringAOP详解1&#xff0c;AOP简介1.1 什么是AOP?1.2 AOP作用1.3 AOP核心概念2&#xff0c;AOP入门案例2.1 需求分析2.2 思路分析2.3 环境准备2.4 AOP实现步骤步骤1:添加依赖步骤2:定义接口与实现类步骤3:定义通知类和通知步骤4:定义切入点步骤5:制作切面步骤6:将通知…

Anaconda+VSCode配置tensorflow

主要参考https://blog.csdn.net/qq_42754919/article/details/106121979vscode的安装以及Anaconda的安装网上有很多教程&#xff0c;大家可以自行百度就行。在安装Anaconda的时候忘记勾选自动添加path&#xff0c;需要手动添加环境变量path下面介绍tensorflow安装教程:1.打开An…

getRequestDispatcher()转发和sendRedirect()重定向介绍与比较

文章目录1. request.getRequestDispatcher()1.1请求转发和请求包含的区别1.2request域2.response.sendRedirect()3.请求转发与重定向的区别比较测试1. request.getRequestDispatcher() getRequestDispatcher()包含两个重要方法&#xff0c;分别是请求转发和请求包含。一个请求…

系分 - 案例分析 - 系统设计

个人总结&#xff0c;仅供参考&#xff0c;欢迎加好友一起讨论 文章目录系分 - 案例分析 - 系统设计结构化设计SD内聚偶然内聚逻辑内聚时间&#xff08;瞬时&#xff09;内聚过程内聚通信内聚顺序内聚功能内聚耦合内容耦合公共耦合外部耦合控制耦合标记耦合数据耦合非直接耦合补…