使用 pytorch训练自己的图片分类模型

news2025/1/13 7:32:55

如何自己训练一个图片分类模型,如果一切从头开始,对于一般公司或个人基本是难以实现的。其实,我们可以利用一个现有的图片分类模型,加上新的分类,这种方式叫做迁移学习,就是把现有的模式知识,转移到新的模型。Pytorch 官网提供已经训练好的模型,可以在此基础上训练自己的模型。我们用的模型是 VGG 分类模型,首先,先运行一个已经训练好的模型可做 1000 个分类。

安装依赖

# 去官网根据系统进行下载
pip3 install torch torchvision torchaudio
pip3 install tqdm

现有模型进行图片识别

可以去百度上下载一个狗或者鸟的图片,运行下面的程序进行识别。

# 导入软件包
import numpy as np
import json
from PIL import Image

import torch
import torchvision
from torchvision import models, transforms

#生成VGG-16模型的实例
use_pretrained = True  # 使用已经训练好的参数
net = models.vgg16(pretrained=use_pretrained)
net.eval()  # 设置为推测模式

# 对输入图片进行预处理的类
class BaseTransform():
    """
    调整图片的尺寸,并对颜色进行规范化。

    Attributes
    ----------
    resize : int
       指定调整尺寸后图片的大小
    mean : (R, G, B)
       各个颜色通道的平均值
    std : (R, G, B)
       各个颜色通道的标准偏差
    """

    def __init__(self, resize, mean, std):
        self.base_transform = transforms.Compose([
            transforms.Resize(resize),  #将较短边的长度作为resize的大小
            transforms.CenterCrop(resize),  #从图片中央截取resize × resize大小的区域
            transforms.ToTensor(),  #转换为Torch张量
            transforms.Normalize(mean, std)  #颜色信息的正规化
        ])

    def __call__(self, img):
        return self.base_transform(img)

# 根据输出结果对标签进行预测的后处理类
class ILSVRCPredictor():
    """
    根据ILSVRC数据,从模型的输出结果计算出分类标签

    Attributes
    ----------
    class_index : dictionary
           将类的index与标签名关联起来的字典型变量
    """

    def __init__(self, class_index):
        self.class_index = class_index

    def predict_max(self, out):
        """
        获得概率最大的ILSVRC分类标签名

        Parameters
        ----------
        out : torch.Size([1, 1000])
            从Net中输出结果

        Returns
        -------
        predicted_label_name : str
            预测概率最高的分类标签的名称
        """
        maxid = np.argmax(out.detach().numpy())
        predicted_label_name = self.class_index[str(maxid)][1]

        return predicted_label_name
# 载入ILSVRC的标签信息,并生成字典型变量
ILSVRC_class_index = json.load(open('./data/imagenet_class_index.json', 'r'))

# 生成ILSVRCPredictor的实例
predictor = ILSVRCPredictor(ILSVRC_class_index)

# 读取输入的图像
image_file_path = './data/jww2.webp'
img = Image.open(image_file_path)  # [ 高度 ][ 宽度 ][ 颜色RGB]

# 完成预处理后,添加批次尺寸的维度
resize = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
transform = BaseTransform(resize, mean, std)  #创建预处理类
img_transformed = transform(img)  # torch.Size([3, 224, 224])
inputs = img_transformed.unsqueeze_(0)  # torch.Size([1, 3, 224, 224])

# 输入数据到模型中,并将模型的输出转换为标签
out = net(inputs)  # torch.Size([1, 1000])
result = predictor.predict_max(out)

# 输出预测结果
print("输入图像的预测结果:", result)

我识别的是一只吉娃娃的图片,结果正确,Chihuahua。

现有的模型已经可以正常工作了,下面就是添加新的分类了,这里使用了蚂蚁和蜜蜂。把 1000 个分类改为了 2个分类。
net.classifier[6] = nn.Linear(in_features=4096, out_features=2)

import glob
import os.path as osp
import random
import numpy as np
import json
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import models, transforms

torch.manual_seed(1234)
np.random.seed(1234)
random.seed(1234)




class ImageTransform():
    """
    图像的预处理类。训练时和推测时采用不同的处理方式
    对图像的大小进行调整,并将颜色信息标准化
    训练时采用 RandomResizedCrop 和 RandomHorizontalFlip 进行数据增强处理


    Attributes
    ----------
    resize : int
       指定调整后图像的尺寸
    mean : (R, G, B)
        各个颜色通道的平均值
    std : (R, G, B)
        各个颜色通道的标准偏差
    """

    def __init__(self, resize, mean, std):
        self.data_transform = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(
                    resize, scale=(0.5, 1.0)), #数据增强处理
                transforms.RandomHorizontalFlip(),  #数据增强处理
                transforms.ToTensor(),  # 转换为张量
                transforms.Normalize(mean, std)  # 归一化
            ]),
            'val': transforms.Compose([
                transforms.Resize(resize),  #调整大小
                transforms.CenterCrop(resize),  #从图像中央截取resize×resize大小的区域
                transforms.ToTensor(), #转换为张量
                transforms.Normalize(mean, std)  #归一化
            ])
        }

    def __call__(self, img, phase='train'):
        """
        Parameters
        ----------
        phase : 'train' or 'val'
            指定预处理所使用的模式
        """
        return self.data_transform[phase](img)

#  创建用于保存蚂蚁和蜜蜂的图片的文件路径的列表变量


def make_datapath_list(phase="train"):
    """
    创建用于保存数据路径的列表

    Parameters
    ----------
    phase : 'train' or 'val'
        指定是训练数据还是验证数据

    Returns
    -------
    path_list : list
       保存了数据路径的列表
    """

    rootpath = "./data/hymenoptera_data/"
    target_path = osp.join(rootpath+phase+'/**/*.jpg')
    print(target_path)

    path_list = []  #  保存到这里

    #  使用 glob 取得包括示例目录的文件路径
    for path in glob.glob(target_path):
        path_list.append(path)

    return path_list


class HymenopteraDataset(data.Dataset):
    """
    蚂蚁和蜜蜂图片的Dataset类,继承自PyTorch的Dataset类

    Attributes
    ----------
    file_list : 列表
        列表中保存了图片路径
    transform : object
        预处理类的实例
    phase : 'train' or 'test'
        指定是学习还是验证
    """

    def __init__(self, file_list, transform=None, phase='train'):
        self.file_list = file_list  # 文件路径列表
        self.transform = transform  # 预处理类的实例
        self.phase = phase  # 指定是train 还是val

    def __len__(self):
        '''返回图片张数'''
        return len(self.file_list)

    def __getitem__(self, index):
        '''
        获取预处理完毕的图片的张量数据和标签
        '''

        #载入第index张图片
        img_path = self.file_list[index]
        img = Image.open(img_path) #[高度][宽度][颜色RGB]

        #对图片进行预处理
        img_transformed = self.transform(
            img, self.phase)  # torch.Size([3, 224, 224])

        #从文件名中抽取图片的标签
        if self.phase == "train":
            label = img_path[30:34]
        elif self.phase == "val":
            label = img_path[28:32]

      #将标签转换为数字
        if label == "ants":
            label = 0
        elif label == "bees":
            label = 1

        return img_transformed, label

#  执行
train_list = make_datapath_list(phase="train")
val_list = make_datapath_list(phase="val")

#执行
size = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
train_dataset = HymenopteraDataset(
    file_list=train_list, transform=ImageTransform(size, mean, std), phase='train')

val_dataset = HymenopteraDataset(
    file_list=val_list, transform=ImageTransform(size, mean, std), phase='val')

#指定小批次尺寸
batch_size = 32

#创建DataLoader
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False)

#集中到字典变量中
dataloaders_dict = {"train": train_dataloader, "val": val_dataloader}

#确认执行结果
batch_iterator = iter(dataloaders_dict["train"])  #转换成迭代器
inputs, labels = next(
    batch_iterator) #取出第一个元素

# 载入已经学习完毕的VGG−16模型
#创建VGG−16模型的实例
use_pretrained = True #指定使用已经训练好的参数
net = models.vgg16(pretrained=use_pretrained)

#指定使用已经训练好的参数
net.classifier[6] = nn.Linear(in_features=4096, out_features=2)

#设定为训练模式
net.train()

print('网络设置完毕 :载入已经学习完毕的权重,并设置为训练模式')

# #设置损失函数
criterion = nn.CrossEntropyLoss()

params_to_update = []

#需要学习的参数名称
update_param_names = ["classifier.6.weight", "classifier.6.bias"]

#除了需要学习的那些参数外,其他参数设置为不进行梯度计算,禁止更新
for name, param in net.named_parameters():
    if name in update_param_names:
        param.requires_grad = True
        params_to_update.append(param)
        print(name)
    else:
        param.requires_grad = False

optimizer = optim.SGD(params=params_to_update, lr=0.001, momentum=0.9)


def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs):

    #epoch循环
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-------------')

        # 每个epoch中的学习和验证循环
        for phase in ['train', 'val']:
            if phase == 'train':
                net.train()  #将模式设置为训练模式
            else:
                net.eval()   #将模式设置为验证模式

            epoch_loss = 0.0  #epoch的合计损失
            epoch_corrects = 0 #epoch的正确答案数量

            #为了确认训练前的验证能力,省略epoch=0时的训练
            if (epoch == 0) and (phase == 'train'):
                continue

            #载入数据并切取出小批次的循环
            for inputs, labels in tqdm(dataloaders_dict[phase]):

                #初始化optimizer
                optimizer.zero_grad()

                #计算正向传播(forward)
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = net(inputs)
                    loss = criterion(outputs, labels) #计算损失
                    _, preds = torch.max(outputs, 1)  #预测标签
                    
  
                    ##训练时的反向传播
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    #计算迭代的结果
                    # 计算迭代的结果
                    epoch_loss += loss.item() * inputs.size(0)  
                    # 更新正确答案数量的总和
                    epoch_corrects += torch.sum(preds == labels.data)

            #显示每个epoch的loss和正解率
            epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
            epoch_acc = epoch_corrects.double(
            ) / len(dataloaders_dict[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

num_epochs=2
train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs)

在这里插入图片描述
通过运行结果可以看到,首次没有训练直接在原始模型进行测试,正确率 33%,第二轮,经过 8 次迭代学习,正确率提高到 72%,这里比较奇怪的是验证集的正确率更高。原因是训练集做了数据增广,有些图片是变形的,所以识别起来更加困难。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1623613.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

重要提醒!别再这样搭建帮助中心系统了

你们有没有这样的经历呢?当你使用某产品或服务时遇到问题,打开产品或服务的帮助中心,但界面设计太复杂,内容搜出来的内容多但是混乱不一致。或者更糟糕的是,帮助中心的界面设计看得人眼花缭乱。 所以,反思一…

全长直线度的检查方法和设备

关键字:全长直线度, 直线度测量仪,直线度测量机,直线度检测,直线度检测设备, 全长直线度的检测是确保机械部件、导轨、机床工作台等在全长范围内直线运动精度的重要手段。以下是一些常用的全长直线度检测方法和设备: --------直角尺和水平仪--------:…

bit、进制、位、时钟(窗口)、OSI七层网络模型、协议、各种码

1.bit与进制 (个人理解,具体电路是非常复杂的) 物理层数据流,bit表示物理层数据传输单位, 一个电路当中,通过通断来表示数字1和0 两个电路要通讯,至少要两根线,一根作为电势参照…

浓眉大眼的Apple开源OpenELM模型;IDM-VTON试衣抱抱脸免费使用;先进的语音技术,能够轻松克隆任何人的声音

✨ 1: openelm OpenELM是苹果机器学习研究团队发布的高效开源语言模型家族 OpenELM是苹果机器学习研究团队开发的一种高效的语言模型,旨在推动开放研究、确保结果的可信赖性、允许对数据和模型偏见以及潜在风险进行调查。其特色在于采用了一种分层缩放策略&#x…

定时器介绍

定时器简介 一、周期定时功能二、PWM功能三、脉冲捕获四、事件计数五、扩展触发功能 一、周期定时功能 定时器的时钟为所选时钟源LRC、OSC、HRC、PLL通过定时器内的预分频器TMRDIV分频得到。 二、PWM功能 包括向上、下、中央计数方式,以向上计数为例计数和引脚产生…

使用excel文件生成sql脚本

目录 1、excel文件脚本变量2、公式示例 前言:在系统使用初期有一些基础数据需要从excel中导入到数据库中,直接导入的话可能有些字段用不上,所以就弄一个excel生成sql的导入脚本,这样可以将需要的数据填到指定的列即可生成sql。 1、…

前端路由的实现原理

当谈到前端路由时,指的是在前端应用中管理页面导航和URL的机制。前端路由使得单页应用(Single-Page Application,SPA)能够在用户与应用交互时动态地加载不同的视图,而无需每次都重新加载整个页面。 在前端开发中&…

【VTKExamples::Meshes】第十八期 OBBDicer

很高兴在雪易的CSDN遇见你 VTK技术爱好者 QQ:870202403 公众号:VTK忠粉 前言 本文分享VTK样例OBBDicer,并解析接口vtkOBBDicer,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞+关注,小易会继续努力分享,一起进步! 你的点赞就是我的动力(^U^)ノ~YO 1. …

AT7456E 贴片TSSOP-28 新版本 OSD字符叠加芯片

AT7456E OSD(On-Screen Display)叠加芯片的应用领域相当广泛,主要用于在视频信号上传递附加信息。根据您提供的信息[2],以下是AT7456E的一些典型应用领域: 1.无人机:用于在无人机的视频传输中叠加关键信息…

NIKKE胜利女神妮姬1.5周年(PC)怎么下载一键下载安装教程一看就会

NIKKE胜利女神妮姬1.5周年(PC)怎么下载?一键下载安装教程一看就会 近日一款新型FPS游戏NIKKE引起了游戏爱好者们的热议,这款游戏是由Shift Up公司开发的一款二次元风格美少女射击类RPG游戏。玩家可以通过抽卡获取不同的角色,并通过主线支线关…

windows下git提交修改文件名大小写提交无效问题

windows系统不区分大小写,以及git提交忽略大小写,git仓库已存在文件A.js,本地修改a.js一般是没有提交记录的,需要手动copy一份出来A.js,再删除A.js文件提交仓库删除后,再提交修改后的a.js文件。 windows决…

Next.js 14 App Router引入 farmer-motion 初始化异常解决,顺带学点知识

前言 farmer-motion 是一个非常好用的动画库,当然用来做组件切换和路由切换过渡更不在话下。 记录一下,Next.js 14 App Router 下引入初始化异常的解决姿势,顺带扯一下 next.js 的知识点; 问题 过渡组件代码 我们拿 farmer-m…

SAP DMS修改文档操作简介

修改DMS文档的事物代码是—CV02N,同样的删除、审批都是用的CV02N对文档进行操作 1、文档的审批,根据我们后台对文档版本的配置通过CV02N对这个文档状态的一个变更 当审批后系统就会显示绿灯,如下图 2、文档的标记删除 我们在CV02N的界面中直接点击删除标记即可。 点击后…

Error opening file a bytes-like object is required,not ‘NoneType‘

错误显示,打开的是一个无效路径的文件 查看json文件内容,索引的路径与json文件保存的路径不同 方法:使用python脚本统一修改json文件路径 import json import os import argparse import cv2 from tqdm import tqdm import numpy as np impo…

算法设计优化——有序向量二分查找算法与Fibonacci查找算法

文章目录 0.概述1.语义定义2. 二分查找(版本A)2.1 原理2.2 实现2.3 复杂度2.4 查找长度 3.Fibonacci查找3.1 改进思路3.2 黄金分割3.3 实现3.4 复杂度分析3.5 平均查找长度 4. 二分查找(版本B)4.1 改进思路4.2 实现4.3 性能4.4 进…

基于CANoe从零创建以太网诊断工程(2)—— TCP/IP Stack 配置的三种选项

🍅 我是蚂蚁小兵,专注于车载诊断领域,尤其擅长于对CANoe工具的使用🍅 寻找组织 ,答疑解惑,摸鱼聊天,博客源码,点击加入👉【相亲相爱一家人】🍅 玩转CANoe&…

手撕netty源码(二)- 初始化ServerBootstrap

文章目录 前言一、ServerBootstrap 的创建和初始化1.1 创建1.2 初始化group1.3 初始化channel1.3 初始化option和attr1.4 初始化handler 和 childHandler 总结 前言 processOn文档跳转 接上一篇:手撕netty源码(一)- NioEventLoopGroup 本篇讲…

uni-app app和h5的通信

uni-app一套代码同时打包安卓、iOS、h5,有一些需要app与h5的交互通信,目前做到了这块的业务,记录如下: 1.去declould官网,找到uni_webview.js下载链接,将uni_webview.js文件下载到本地,修改uni_webview.js内部配置,将uni修改为webUni,修改好的文件已放到…

SpringBoot中Bean的创建过程及扩展操作点 @by_TWJ

目录 1. 类含义2. Bean创建过程 - 流程图3. 例子3.1. 可变属性注入到实体中3.2. 模拟Bean创建的例子 1. 类含义 BeanDefinition - 类定义,为Bean创建提供一些定义类的信息。实现类如下: RootBeanDefinition - 类定义信息,包含有父子关系的Be…

智慧健康旅居养老产业,做智慧旅居养老服务的公司

随着社会的进步和科技的飞速发展,传统的养老模式已经无法满足 现代老年人的多元化 需求。智慧健康旅居养老产业应运而生,成为了一种新型的养老模式,旨在为老年人提供更加舒适、便捷、安全的养老生活。随着社会的进步和人口老龄化趋势的加剧&a…