【项目实践】猫十二分类

news2024/12/24 22:08:29

【数据科学项目实践】基于ResNet和Inception v3的猫十二分类迁移学习

一、项目背景

本项目来源于飞浆平台的图像分类学习赛。指路链接

  • 代码和结果来源于我的小组同学,没有做任何的改动,我这边仅做一个总结归纳,以便学习和复盘

简单把赛题Copy一下:

本场比赛要求参赛选手对十二种猫进行分类,属于CV方向经典的图像分类任务。图像分类任务作为其他图像任务的基石,可以让大家更快上手计算机视觉。

数据集

比赛数据集包含12种猫的图片,并划分为训练集与测试集。

训练集: 提供高清彩色图片以及图片所属的分类,共有2160张猫的图片,含标注文件。

测试集: 仅提供彩色图片,共有240张猫的图片,不含标注文件。

二、Baseline

2.1 准备阶段

主要是导入一些要用到的模块:

import os
import cv2
import torch
import torch.nn as nn
from torchvision import models,transforms
from torch.utils.data import DataLoader,Dataset
import numpy as np
from PIL import Image
from torch.optim import lr_scheduler
import copy

2.2 数据读取阶段

这个阶段就是如何将数据读取到模型中来,由于猫猫是图像数据,所以这边将其读取成数字图像一般是通过数组来存在内存中的,考虑到中间过程的可视化,我们通过PIL来读取Image类型的数据。这步可以写作:

x=np.fromfile(imgPath,dtype=np.float32) # 读取成ndarray
x=cv2.imdecode(x,1) # 将区间转化为[0,255]
img=PIL.Image.fromarray(x) # 读取成Image对象

在这里插入图片描述

上图中,左边的是Image类型的数据,右边是cv读取的数据,可以发现发生了颜色通道的调换。实际上,读取到cv这部分就好了,可以调用多窗口的imshow进行数据可视化。

我们现在拿到了猫猫图像!那么接下来就要拿到猫猫的标签啦,一般情况下,我们会将数据跟标签记录在一个文档里,每一行对应一个数据(图片)路径和一个标签:

# 文件标签
filelist=r"data_split_list.txt"
imgs,labels=[],[] # 存储列表

with open(filelist) as f:
    lines=[_.strip() for _ in f] # 去除空白
    np.random.shuffle(lines) # 随机打乱
    for l in lines:
        img_path,label=l.split('\t') # 获取图片路径和标签
        img=Image.fromarray(cv2.imdecode(np.fromfile(img_path,np.float32),1))
        imgs.append(img)
        labels.append(label)

我们将这部分工作封装成一个函数,就可以实现数据的读取了。

接下来的工作,就是将数据转化为PyTorch接受的格式啦。众所周知,PyTorch的模型训练跟推理一般是通过迭代一个DataLoader对象来进行的,而DataLoader对象的数据集是一个DataSet类。所以这里我们需要构建一个Dataset类啦:

class myData(Dataset):
    
    def __init__(self):
        super(myData,self).__init__()
        self.data=[]
    
    def __getitem__(self,x):
        return self.data[x]
    
    def __len__(self):
        return len(self.data)

嗯,把上面三个函数填完就阔以啦。

对于图像数据,我们需要应用一个transforms,这里做最简单的变换:转为Tensor,尺寸裁剪,标准化

self.transform=transforms.Compose(
    transforms.ToTensor(),
    transforms.Resize((299,299)),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
)

最终的Dataset如下:

class myData(Dataset):

    def __init__(self,kind):
        super(myData, self).__init__()
        self.mode=kind
        self.transform=transforms.Compose(
            transforms.ToTensor(),
            transforms.Resize((299,299)),
            transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
        )

        if kind=="test":
            self.imgs=self.load_origin_data()
        else:
            self.imgs,self.labels=self.load_origin_data()

    def __getitem__(self, item):
        if self.mode=="test":
            return self.transform(self.imgs[item])
        else:
            return self.transform(self.imgs[item]),torch.tensor(self.labels[item])

    def __len__(self):
        return len(self.imgs)

    def load_origin_data(self):
        filelist = './data/%s_split_list.txt' % self.mode
        imgs,labels=[],[]
        data_dir=os.getcwd()+"/data"
        if self.mode=='train' or self.mode=='val':
            with open(filelist) as f:
                lines=[_.strip() for _ in f]
                if self.mode=='train':
                    np.random.shuffle(lines)
                    for l in lines:
                        img_path,label=l.split('\t')
                        img_path=os.path.join(data_dir,img_path)
                        try:
                            img=Image.fromarray(cv2.imdecode(np.fromfile(img_path,dtype=np.float32),1))
                            imgs.append(img)
                            labels.append(int(label))
                        except Exception("The path %s"%img_path+" may be wrong") as e:
                            print(e)
                            continue
                    return imgs,labels
                elif self.mode=="test":
                    full_lines = os.listdir('data/cat_12_test/')
                    lines = [line.strip() for line in full_lines]
                    for img_path in lines:
                        img_path = os.path.join(data_dir, "cat_12_test/", img_path)
                        img = Image.open(img_path)
                        imgs.append(img)
                    return imgs

2.3 模型训练

我们刚刚说PyTorch的模型训练跟推理一般是通过迭代一个DataLoader对象来进行的,现在就是需要构建这个东西啦:

def get_Dataloader():
    img_datasets = {x: myData(x) for x in ['train', 'val', 'test']}
    dataset_sizes = {x: len(img_datasets[x]) for x in ['train', 'val', 'test']}

    train_loader = DataLoader(
        dataset=img_datasets['train'],
        batch_size=24,
        shuffle=True
    )

    val_loader = DataLoader(
        dataset=img_datasets['val'],
        batch_size=1,
        shuffle=False
    )

    test_loader = DataLoader(
        dataset=img_datasets['test'],
        batch_size=1,
        shuffle=False
    )

    dataloaders = {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader
    }
    return dataset_sizes,dataloaders

接下来就是单纯的训练过程了。步骤总结如下:

  • 参数设置阶段
    • 设置GPU
    • 设置优化器、损失函数、学习策略
  • 训练过程
    • 迭代DataLoader
    • 优化器梯度清零
    • 模型推理
    • 误差计算
    • 反向传播
    • 更新优化器、学习率
  • 模型评估
    • 计算每轮的误差累计值、精度
    • 选择最优精度并进行模型保存
def Train(model,criterion,optimizer,scheduler,num_epoches=25):
    dataset_sizes,dataloaders=get_Dataloader()
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    best_model_wts=copy.deepcopy(model.state_dict())
    best_acc=0.0

    for epoch in range(num_epoches):
        print("Epoch {}/{}".format(epoch+1,num_epoches))

        for phase in ['train','val']:
            if phase=="train":
                model.train()
            else:
                model.eval()

            trian_loss=0.0
            train_corrects=0

            for inputs,labels in dataloaders[phase]:
                inputs,labels=inputs.to(device),labels.to(device)
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase=="train"):
                    # 上下文管理器,参数是Bool,用于确定是否对Block内的语句进行求导
                    y_pre=model(inputs)
                    _,y_pre=torch.max(y_pre,1)
                    loss=criterion(y_pre,labels)

                    if phase=="train":
                        loss.backward()
                        optimizer.step()

                trian_loss+=loss.item()*inputs.size(0)
                train_corrects+=torch.sum(y_pre==labels)
            if phase=="train":
                scheduler.step()

            epoch_loss=trian_loss/dataset_sizes[phase]
            epoch_acc=train_corrects.float()/dataset_sizes[phase]

            print("{} Loss :{:.4f} Acc {:.4}".format(phase,epoch_loss,epoch_acc))

            if phase=="val" and epoch_acc>best_acc:
                best_acc=epoch_acc
                best_model_wts=copy.deepcopy(model.state_dict())
    print("Best val Acc : {:4f}".format(best_acc))
    model.load_state_dict(best_model_wts)
    return model

三、迁移学习

迁移学习(Transfer Learning)就是利用预训练好的大模型参数去学习其他数据的分布。

这个过程我们一般不希望原始模型参数改变,因而一般需要做如下工作:

for param in model.parameters():
    param.requires_grad=False

然后,我们需要构架最后一层全连接层,用来学习新的数据集:

model.fc=nn.Linear(2048,num_classes)

也就是最后需要训练的就是这个全连接层了。

def Inception(device):
    # 用训练好的模型进行迁移
    model_ft=models.inception_v3(pretrained=True)
    # model_ft=models.resnet50(pretrained=True)
    # model_ft=models.alexnet(pretrained=True)

    num_ftrs=model_ft.fc.in_features
    model_ft.fc=nn.Linear(num_ftrs,12) # 设置全连接层最终结果
    
    model_ft=model_ft.to(device)

    cirterion=nn.CrossEntropyLoss()
    optimizer_ft=torch.optim.SGD(model_ft.parameters(),lr=0.001,momentum=0.9)
    exp_lr_scheduler=lr_scheduler.StepLR(optimizer_ft,step_size=5,gamma=0.1)
    model_ft=Train(model_ft,cirterion,optimizer_ft,exp_lr_scheduler,num_epoches=30)

四、结果分析

  • Inception

    Epoch 30/30
    train Loss: 0.1065 Acc: 0.9858
    val Loss: 0.3026 Acc: 0.8983
    Best val Acc: 0.918336
    
  • AlexNet

    Epoch 30/30
    train Loss: 0.1403 Acc: 0.9601
    val Loss: 0.6815 Acc: 0.7750
    Best val Acc: 0.779661
    
  • ResNet50

    Epoch 30/30
    train Loss: 0.0480 Acc: 0.9973
    val Loss: 0.3157 Acc: 0.9060
    Best val Acc: 0.909091
    

中间部分特征图的结果如下:

在这里插入图片描述

特征图嘛,主打的就是一个抽象。可以发现同一张图经过不同的卷积核作用后,有了全新的高维特征,这些特征也主打的就是一个难以解释,反正就看个乐。

在这里插入图片描述

基本上7个epoch就收敛了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/570316.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

chatgpt赋能python:Python句柄操作

Python 句柄操作 Python作为一门强大又易于使用的编程语言,其在内存管理和资源分配方面广受欢迎。其中Python句柄操作是其独特之处之一。在本文中,我们将介绍Python句柄操作的概念,方法,以及模块,并讨论Python中句柄操…

统计学中的t检验 、f检验、卡方检验

1.1数据的种类 我们都知道,一般数据可以分为两类,即定量数据(数值型数据)和定性数据(非数值型数据),定性数据很好理解,例如人的性别,姓名这些都是定性数据。 定量数据可…

CSS3煎制荷包蛋动画特效,优质男士表白必备

你有多久没吃过早餐了?你是否每天忙碌到很晚,结果导致早上起来也很晚,匆匆忙忙来不及吃早餐,更别说自己做了。一直到现在,你有多久没有吃到过母亲做的早饭了?我们在外奔波,希望家人安康&#xf…

【C语言】几种方法解决问题:C6031返回值被忽略:“scanf” (保姆级图文)

目录 错因分析1. 使用_s结尾的安全函数版本(推荐)2. 在本项目中关闭警告(作用一个项目)3. 在本文件中关闭警告(作用一个文件)总结 欢迎关注 『C语言』 系列,持续更新中 欢迎关注 『C语言』 系列…

分布式协调服务--zookeeper

目录 一、概述 1、zookeeper有两种运行状态 zookeeper架构的角色: 2、Paxos算法:消息传递的一致性算法 3、ZAB协议 Zab 协议实现的作用 Zab协议核心 Zab协议内容 消息广播 崩溃恢复 实现原理 协议实现 一、概述 zookeeper官网 zookeeper官…

Trace32使用Data.Test和Data.TestList命令测试内存类型以及完整性

我们在debug的时候,可以使用Trace32自带的一些命令快速地检测目标系统的内存的类型和完整性(是否可读或可写),以便快速排除内存缺陷带来的干扰。 目录 Data.Test: 内存完整性测试 Memory integrity test Data.TestL…

Android进阶 View事件体系(二):从源码解析View的事件分发

Android进阶 View事件体系(二):从源码解析View的事件分发 内容概要 本篇文章为总结View事件体系的第二篇文章,前一篇文章的在这里:Android进阶 View事件体系(一):概要介绍和实现Vie…

chatgpt赋能python:Python动态增加成员变量简介

Python动态增加成员变量简介 Python是著名的解释型编程语言,在众多开源项目中得到了广泛的应用。它以简洁明了的语法和高效的运行速度而闻名,成为了许多开发者的首选。 Python提供了极大的灵活性,使得我们可以随意添加、修改和删除对象的属…

chatgpt赋能python:Python切割技巧:如何用Python切割字符串和列表

Python切割技巧:如何用Python切割字符串和列表 Python是一种高级编程语言,被广泛用于数据分析、机器学习、Web应用程序等领域。在Python编程中,切割技巧是一项必备技能。 什么是切割技巧? 切割技巧是指用一种编程语言&#xff…

chatgpt赋能python:Python列表倒序-从入门到实践

Python列表倒序 - 从入门到实践 Python是一种高级编程语言,被广泛运用于web开发、科学计算、数据分析等领域,也是初学者学习的首选语言之一。Python的列表(List)是其中一个常用的数据类型。在本文中,我们将深入探讨Python列表倒序的方法&…

chatgpt赋能python:Python列表反向:如何用简单的代码将列表元素反转

Python列表反向:如何用简单的代码将列表元素反转 在很多编程语言中,将列表元素反转是一项常见的任务。Python也不例外。Python内置函数提供了一种非常直接的方式来将列表元素反转,而不需要费力地创建一个新列表。 什么是列表反向&#xff1…

chatgpt赋能python:Python动态代码的SEO优化技巧

Python 动态代码的SEO优化技巧 Python是一种常用的编程语言,它以简化开发流程和易于阅读的代码著称。Python动态代码能够让开发者更快捷方便地进行编码,并且能够改善SEO表现。在本文中,我们将着重介绍Python动态代码与SEO优化涉及的技巧。 …

chatgpt赋能python:Python分组匹配:了解正则表达式中的分组匹配技巧

Python 分组匹配: 了解正则表达式中的分组匹配技巧 在 Python 中,正则表达式是一种重要的文本处理工具,它可以帮助我们在字符串中匹配、查找和替换特定的文本模式。其中,分组匹配是正则表达式的重要特性之一,它可以将匹配的结果按…

快速理解会话跟踪技术Cookie和Session

文章目录 会话跟踪技术客户端会话跟踪技术Cookie服务端会话跟踪技术Session 会话跟踪技术 会话:客服端和服务端的多次请求与响应称为会话。 会话跟踪:服务器需要识别多次请求是否来自同一浏览器,在同一次会话多次请求中共享数据。 HTTP协议是…

chatgpt赋能python:Python加解密算法简介

Python加解密算法简介 在当今数字化的时代,数据的安全性变得至关重要。而加密算法就成为了保障数据安全的重要手段之一。Python作为一门高级编程语言,提供了许多加密算法库,使得开发人员可以轻松地实现加密功能。本文将着重介绍Python中一些…

机器学习模型——回归模型

文章目录 监督学习——回归模型线性回归模型最小二乘法求解线性回归代码实现引入依赖:导入数据:定义损失函数:定义核心算法拟合函数:测试:画出拟合曲线: 多元线性回归梯度下降求线性回归梯度下降和最小二乘…

chatgpt赋能python:Python中%取模操作的介绍

Python中%取模操作的介绍 在Python中,取模操作使用符号“%”表示,它的作用是取两个数相除的余数。例如,10 % 3等于1,因为10除以3的余数为1。这个操作可以用在很多场合,比如判断一个数是奇数还是偶数,或者判…

带你开发一个远程控制项目---->STM32+标准库+阿里云平台+传感器模块+远程显示。

目录 本次实验项目: 下次实验项目: 本次项目视频结果/APP/实物展示 实物展示 APP展示 视频展示 模块选择说明; 温湿度传感器模块介绍 光照传感器介绍 ESP8266-01S模块介绍 本次实验项目: 项目清单平台单片机语言实现温湿度传感器模…

Reinforcement Learning | 强化学习十种应用场景及新手学习入门教程

文章目录 1.在自动驾驶汽车中的应用2.强化学习的行业自动化3.强化学习在贸易和金融中的应用4.NLP(自然语言处理)中的强化学习5.强化学习在医疗保健中的应用6.强化学习在工程中的应用7.新闻推荐中的强化学习8.游戏中的强化学习9.实时出价——强化学习在营…

Redis中的Reactor模型源码探索

文章目录 摘要了解Linux的epoll了解Reactor模型 源码initServerinitListenersaeMain 事件管理器aeProcessEvents读事件 摘要 有时候在面试的时候会被问到Redis为什么那么快?有一点就是客户端请求和应答是基于I/O多路复用(比如linux的epoll)的…