pytorch-训练自定义数据集实战

news2025/1/10 12:05:50

目录

  • 1. 步骤
  • 2. 加载数据
    • 2.1 继承Dataset
      • 2.1.1 生成name2label
      • 2.1.2 生成image path, label的文件
      • 2.1.3 __len__
      • 2.1.3 __getitem__
      • 2.1.4 数据切分为train、val、test
  • 3. 建立模型
  • 4. 训练和测试
  • 4. 完整代码

1. 步骤

  • 加载数据
  • 创建模型
  • 训练和测试
  • 迁移学习

2. 加载数据

这里以宝可梦动画图片为数据集
在这里插入图片描述
下载地址:
链接:https://pan.baidu.com/s/1TbXKNIBitXk_o-oVAAiX-A?pwd=py3r
提取码:py3r

数据集各分类情况和切分比例见下图:
在这里插入图片描述

2.1 继承Dataset

继承torch.utils.data.Dataset,实现__len__和__getitem__函数
__len__是获取所有数据集的数量
__getitem__获取数据集中指定index的image tensor和对应的分类label
实现这两个函数的思路:

  • 将数据集所有文件名加载到list中,通过len([images]),即可实现__len__
  • 生成数据集中所有的image path和label,读取并预处理image,即可实现__getitem__

2.1.1 生成name2label

数据集文件结构是pokemon\bulbasaur\00000000.png,pokemon下的每个文件夹代表一个分类,因此就可以实现下面的代码生成一个name2label

 self.name2label = {} # "sq...":0
 for name in sorted(os.listdir(os.path.join(root))):
     if not os.path.isdir(os.path.join(root, name)):
         continue

     self.name2label[name] = len(self.name2label.keys())

2.1.2 生成image path, label的文件

获取pokemon目前下所有数据文件的路径放到images中,遍历images,通过每条数据文件路径中的分类文件夹名称从name2label获取到对应的label,然后写入到文件中。
代码如下:

       if not os.path.exists(os.path.join(self.root, filename)):
            images = []
            for name in self.name2label.keys():
                # 'pokemon\\mewtwo\\00001.png
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))

            # 1167, 'pokemon\\bulbasaur\\00000000.png'
            print(len(images), images)

            random.shuffle(images)
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images: # 'pokemon\\bulbasaur\\00000000.png'
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    # 'pokemon\\bulbasaur\\00000000.png', 0
                    writer.writerow([img, label])
                print('writen into csv file:', filename)

2.1.3 len

    def __len__(self):

        return len(self.images)

2.1.3 getitem

预处理包括resize、randomRotation、ToTensor、Normalize等

def __getitem__(self, idx):
    # idx~[0~len(images)]
      # self.images, self.labels
      # img: 'pokemon\\bulbasaur\\00000000.png'
      # label: 0
      img, label = self.images[idx], self.labels[idx]

      tf = transforms.Compose([
          lambda x:Image.open(x).convert('RGB'), # string path= > image data
          transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),
          transforms.RandomRotation(15),
          transforms.CenterCrop(self.resize),
          transforms.ToTensor(),
          transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
      ])

      img = tf(img)
      label = torch.tensor(label)


      return img, label

2.1.4 数据切分为train、val、test

数据切分比例6:2:2

class Pokemon(Dataset):

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()

        self.root = root
        self.resize = resize

        self.name2label = {} # "sq...":0
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue

            self.name2label[name] = len(self.name2label.keys())

        # print(self.name2label)

        # image, label
        self.images, self.labels = self.load_csv('images.csv')

        if mode=='train': # 60%
            self.images = self.images[:int(0.6*len(self.images))]
            self.labels = self.labels[:int(0.6*len(self.labels))]
        elif mode=='val': # 20% = 60%->80%
            self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
            self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
        else: # 20% = 80%->100%
            self.images = self.images[int(0.8*len(self.images)):]
            self.labels = self.labels[int(0.8*len(self.labels)):]

3. 建立模型

使用前面实现的Resnet18https://editor.csdn.net/md/?articleId=140032483

4. 训练和测试

  • 数据加载
  • 实例化模型、优化器、loss函数
  • epoch循环、模型数据输入、计算loss、backward、优化器迭代
  • validation模型、保存模型
  • test模型
    代码:
def evalute(model, loader):
    model.eval()
    
    correct = 0
    total = len(loader.dataset)

    for x,y in loader:
        x,y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()

    return correct / total

def ResNet18():
    return ResNet(ResBlk, [2, 2, 2, 2], 10)
    
def main():

    model = ResNet18().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()


    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):

        for step, (x,y) in enumerate(train_loader):

            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)
            
            model.train()
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        if epoch % 1 == 0:

            val_acc = evalute(model, val_loader)
            if val_acc> best_acc:
                best_epoch = epoch
                best_acc = val_acc

                torch.save(model.state_dict(), 'best.mdl')

                viz.line([val_acc], [global_step], win='val_acc', update='append')


    print('best acc:', best_acc, 'best epoch:', best_epoch)

    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)

4. 完整代码

train.py

import  torch
from    torch import optim, nn
import  visdom
import  torchvision
from    torch.utils.data import DataLoader

from    pokemon import Pokemon
from    resnet import ResNet18



batchsz = 32
lr = 1e-3
epochs = 10

device = torch.device('cuda')
torch.manual_seed(1234)


train_db = Pokemon('pokemon', 224, mode='train')
val_db = Pokemon('pokemon', 224, mode='val')
test_db = Pokemon('pokemon', 224, mode='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                          num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)


viz = visdom.Visdom()

def evalute(model, loader):
    model.eval()
    
    correct = 0
    total = len(loader.dataset)

    for x,y in loader:
        x,y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()

    return correct / total

def main():

    model = ResNet18().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()


    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):

        for step, (x,y) in enumerate(train_loader):

            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)
            
            model.train()
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        if epoch % 1 == 0:

            val_acc = evalute(model, val_loader)
            if val_acc> best_acc:
                best_epoch = epoch
                best_acc = val_acc

                torch.save(model.state_dict(), 'best.mdl')

                viz.line([val_acc], [global_step], win='val_acc', update='append')


    print('best acc:', best_acc, 'best epoch:', best_epoch)

    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)





if __name__ == '__main__':
    main()

pokemon.py

import  torch
import  os, glob
import  random, csv

from    torch.utils.data import Dataset, DataLoader

from    torchvision import transforms
from    PIL import Image


class Pokemon(Dataset):

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()

        self.root = root
        self.resize = resize

        self.name2label = {} # "sq...":0
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue

            self.name2label[name] = len(self.name2label.keys())

        # print(self.name2label)

        # image, label
        self.images, self.labels = self.load_csv('images.csv')

        if mode=='train': # 60%
            self.images = self.images[:int(0.6*len(self.images))]
            self.labels = self.labels[:int(0.6*len(self.labels))]
        elif mode=='val': # 20% = 60%->80%
            self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
            self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
        else: # 20% = 80%->100%
            self.images = self.images[int(0.8*len(self.images)):]
            self.labels = self.labels[int(0.8*len(self.labels)):]





    def load_csv(self, filename):

        if not os.path.exists(os.path.join(self.root, filename)):
            images = []
            for name in self.name2label.keys():
                # 'pokemon\\mewtwo\\00001.png
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))

            # 1167, 'pokemon\\bulbasaur\\00000000.png'
            print(len(images), images)

            random.shuffle(images)
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images: # 'pokemon\\bulbasaur\\00000000.png'
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    # 'pokemon\\bulbasaur\\00000000.png', 0
                    writer.writerow([img, label])
                print('writen into csv file:', filename)

        # read from csv file
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                # 'pokemon\\bulbasaur\\00000000.png', 0
                img, label = row
                label = int(label)

                images.append(img)
                labels.append(label)

        assert len(images) == len(labels)

        return images, labels



    def __len__(self):

        return len(self.images)


    def denormalize(self, x_hat):

        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        # x_hat = (x-mean)/std
        # x = x_hat*std = mean
        # x: [c, h, w]
        # mean: [3] => [3, 1, 1]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        # print(mean.shape, std.shape)
        x = x_hat * std + mean

        return x


    def __getitem__(self, idx):
        # idx~[0~len(images)]
        # self.images, self.labels
        # img: 'pokemon\\bulbasaur\\00000000.png'
        # label: 0
        img, label = self.images[idx], self.labels[idx]

        tf = transforms.Compose([
            lambda x:Image.open(x).convert('RGB'), # string path= > image data
            transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        img = tf(img)
        label = torch.tensor(label)


        return img, label





def main():

    import  visdom
    import  time
    import  torchvision

    viz = visdom.Visdom()

    # tf = transforms.Compose([
    #                 transforms.Resize((64,64)),
    #                 transforms.ToTensor(),
    # ])
    # db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
    # loader = DataLoader(db, batch_size=32, shuffle=True)
    #
    # print(db.class_to_idx)
    #
    # for x,y in loader:
    #     viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
    #     viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
    #
    #     time.sleep(10)


    db = Pokemon('pokemon', 64, 'train')

    x,y = next(iter(db))
    print('sample:', x.shape, y.shape, y)

    viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))

    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)

    for x,y in loader:
        viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))

        time.sleep(10)

if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1952932.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

打造创新项目:从理念到市场的成功之路

打造创新项目:从理念到市场的成功之路 前言为何创新?如何创新?创新的意义 一、深入市场,洞察行业脉搏二、精准定位,锁定目标市场三、全面评估,确保项目可行性四、创新引领,打造独特卖点五、开放…

二叉树_堆(下卷)

前言 接前面两篇的内容,接着往下讲二叉树_堆相关的内容。 正文 那么,回到冒泡排序与堆排序的比较。 我们知道冒泡排序的时间复杂度为 O ( N 2 ) O(N^2) O(N2),这个效率是不太好的。 那么,我们的堆排序的时间复杂度如何呢&…

Linux:进程概述(什么是进程、进程控制块PCB、并发与并行、进程的状态、进程的相关命令)

进程概述 (1)What(什么是进程) 程序:磁盘上的可执行文件,它占用磁盘、是一个静态概念 进程:程序执行之后的状态,占用CPU和内存,是一个动态概念;每一个进程都有一个对应的进程控制块…

云计算复习--分布式存储系统

分布式存储 分布式存储系统是一种将数据分散在多个独立节点上,并通过网络进行数据传输和访问的存储系统 分布式存储的特点:可扩展性、高可用性、容错性、高性能等。分布式存储系统能够水平扩展存储容量和性能,提供持续可用的数据存储服务&am…

【在开发小程序的时候如何排查问题】

在开发小程序的时候如何排查问题 在最近开发小程序的时候,经常出现本地在浏览器中调试没有问题,但是一发布到预发环境就出现各种个样的问题 手机兼用性问题 有时候会出现苹果🍎手机键盘弹出,导致ui界面高度出现异常边界问题&#…

for循环打印1~10之间数字

对于for循环之前了解不够的同学可以看之前的我写的介绍 我们这里直接上代码 #include<stdio.h> int main() {int i 0;for (i 1; i < 11; i){printf("%d\n", i);}return 0; }

0724_驱动1 字符设备驱动内部实现

一、字符设备驱动内部实现工作原理 二、分布实现字符设备驱动API接口 分配对象&#xff1a; #include <linux/cdev.h> struct cdev *cdev_alloc(void) 函数功能&#xff1a;分配对象struct cdev *结构体指针 参数&#xff1a;无 返回值&#xff1a;成功返回struct cdev *…

人工智能类——计算机科学与技术

计算机科学与技术是一个非常大的门类。目前计算机科学与技术类招生的专业主要有计算机科学与技术、软件工程、网络工程、信息安全、物联网工程等&#xff0c;后面的几个专业是计算机科学与技术的重要分支&#xff0c;而这个门类的其他分支并没有单列出来一个本科专业&#xff0…

实战|EDU挖掘记录-某学校sql注入挖掘记录

本文来源无问社区&#xff0c;更多实战内容&#xff0c;渗透思路尽在无问社区http://www.wwlib.cn/index.php/artread/artid/9755.html 某大学的办公系统&#xff0c;学号是我从官网下载的优秀人员名单找到的&#xff0c;初始密码为姓名首字母加身份证后六位&#xff0c;我是社…

高级及架构师高频面试题-基础型

1、设计模式有哪些原则&#xff08;待解释的更直白&#xff09; 单一职责原则&#xff1a;一个类或方法应只负责一项职责&#xff0c;避免一个类因为多个变化原因而改变。开闭原则&#xff1a;软件实体应对扩展开放&#xff0c;对修改封闭。比如要增加用户类别的时候可以新增一…

Java高频面试题分享

文章目录 1. 策略模式怎么控制策略的选取1.1 追问&#xff1a;如果有100种策略呢&#xff1f;1.2 追问&#xff1a;什么情况下初始化Map 2. 什么是索引&#xff1f;什么时候用索引&#xff1f;2.1 追问&#xff1a;怎么判断系统什么时候用量比较少2.2 追问&#xff1a;如何实时…

树 ----- 基础学习

树 树&#xff1a;n&#xff08;n>0&#xff09;个结点的有限集合。n 0 ,空树。 在任意一个非空树中&#xff0c; 1&#xff0c;有且仅有一个特定的根结点 2&#xff0c;当n>1 时&#xff0c;其余结点可分为m个互不相交的有限集合T1,T2,T3.。。。。Tm&#xff0c;其中每…

使用 uPlot 在 Vue 中创建交互式图表

本文由ScriptEcho平台提供技术支持 项目地址&#xff1a;传送门 使用 uPlot 在 Vue 中创建交互式图表 应用场景介绍 uPlot 是一个轻量级、高性能的图表库&#xff0c;适用于创建各种交互式图表。它具有丰富的功能&#xff0c;包括可自定义的轴、网格、刻度和交互性。本篇博…

脊髓损伤的小伙伴锻炼贴士

Hey小伙伴们~&#x1f44b; 今天要跟大家聊一个超燃又超温馨的话题&#xff01;&#x1f31f; 对于我们脊髓损伤的小伙伴们来说&#xff0c;保持身体活力&#xff0c;不仅是健康的小秘诀&#xff0c;更是拥抱美好生活的超能量哦&#xff01;&#x1f4aa; #脊髓损伤# 首先&…

【ffmpeg命令入门】Nginx的安装与制作HLS流媒体服务器

文章目录 前言Nginx简介Ubuntu安装Nginxffmpeg生成HLS流媒体1. 生成HLS流媒体命令说明 配置Nginxffplay播放m3u8 总结 前言 在数字内容传输和流媒体服务中&#xff0c;HLS&#xff08;HTTP Live Streaming&#xff09;已经成为一种流行的解决方案&#xff0c;特别是在视频直播…

FPGA FIFO IP核(2)- 配置与调用

前言 上上期介绍了FIFO IP核理论方面的一些内容&#xff0c;接下来开始进行FIFO IP核的配置和使用部分。 FIFO IP核再理解 关键点 先进先出&#xff1a;数据按顺序写入FIFO&#xff0c;先被写入的数据在读取的时候先被读出。 FIFO存储器没有地址线。 FIFO主要作为缓存&#…

C语言 | Leetcode C语言题解之第275题H指数II

题目&#xff1a; 题解&#xff1a; int hIndex(int* citations, int citationsSize) {int left 0, right citationsSize - 1;while (left < right) {int mid left (right - left) / 2;if (citations[mid] > citationsSize - mid) {right mid - 1;} else {left mi…

时效性知识点是否值得花时间学习和研究

新趋势 智能大模型训练成本与人才培养成本之间的博弈。 视频 录了个断断续续的视频&#xff1a; 编程简单吗&#xff1f;为什么技术型内容几乎停更了&#xff1f; 代码形式的程序 /** Created by ArduinoGetStarted.com** This example code is in the public domain** Tuto…

[算法]插入排序和希尔排序

这里简单的介绍一下插入排序和希尔排序的算法实现&#xff0c;为简单起见&#xff0c;排序为升序且排序的数组是整形数组。 一、插入排序 &#xff08;一&#xff09;、算法思路 把数组里的第一个元素视为有序的&#xff0c;然后取第二个元素与前面的元素作比较&#xff0c;如…

2024钉钉杯A题思路详解

文章目录 一、问题一1.1 问题1.2 模型1.3 目标1.4 思路1.4.1 样本探究1.4.2 数据集特性探究&#xff1a;1.4.3 数据预处理1.4.4 数据趋势可视化1.4.5 ARIMA和LSTM两种预测模型1.4.6 参数调整 二、问题二2.1 问题2.2 模型2.3 目标2.4 思路2.4.1 样本探究2.4.2 数据集特性探究2.4…