【PyTorch 卷积】实战自定义的图片归类

news2024/11/24 17:20:14

前言

        卷积神经网络是一类包含卷积计算且具有深度结构的前馈神经网络,是深度学习的代表算法之一,它通过卷积层、池化层、全连接层等结构,可以有效地处理如时间序列和图片数据等。关于卷积的概念网络上也比较多,这里就不一一描述了。实战为主当然要从实际问题出发,用代码的方式加深印象。在写代码前,我先说一下为什么我要写这篇文章?

        之前我也用 Tensorflow.js 跟着别人试过图片分类,虽然结果是有了,但是对代码的理解和印象并不深刻。后来由于工作业务原因才接触 PyTorch,发现这个框架更好上手,整一圈后就想用这个把之前用得图片也实现一下分类。开始也是看文章实现,但是网上大部分都是用 MNIST 数据集实现的手写字识别,而业务中有时就是一些指定的不规则小众图片识别,所以下面就简单实现一个自定义的图片集归类。

流程

  • 根据自己的定义,收集图片并归类
  • 读取图片数据和归类标签,保存数据集
  • 固定图片大小 (会变形),归一化转张量
  • 定义超参数,损失函数和优化器等
  • 炼丹,重复查看损失值准确率等指标
  • 保存模型参数,加载测试图片分类效果

环境

  • Python 3.8
  • Torch 1.9.0
  • Pillow 10.0
  • Torchvision
  • Numpy
  • Pandas
  • Matplotlib

编码

        写代码前已经把需要的图片做好了分类,上面的依赖包也已经安装完毕。由于只是演示这里没有用预训练模型(ResNet、VGG),因为训练时要用的是 Tensor,所以需要先读取文件夹内的图片先转化为 PIL 的对象数据或 Numpy 数据,然后可以对图片进行调整,最后全都转成 Tensor(也可以跳过 PIL 直接转张量)。这里需要注意的是对灰彩图片通道,不同尺寸图的统一处理,就是灰色图的单通道要通过复制的方式创建三个通道,所以图片设置一样的像素大小。因为在卷积网络中,输入的通道数和输入大小要一致,不然可能在训练中报错。

图片数据生成

        这里就是遍历各个分类文件夹的图片转换为对象信息数据,和提取所有分类,分别保存到指定位置,当然也可以在这里划分训练数据,校验数据,测试数据,需要的可以扩展这里就跳过了。

# -*- coding: utf-8 -*-
import os
import pickle as pkl
import pandas as pd
from PIL import Image

all_cate = []
data_set = []
directory = "./data/train"
for index, data in enumerate(os.walk(directory)):
    root, dirs, files = data

    if index == 0:
        all_cate += dirs
    else:
        sorted(all_cate)

        root_names = root.split("\\")
        dir_name = root_names[-1]

        for img in files:
            img_path = root + "\\" + img
            img_np = Image.open(img_path)
            dict = {}
            dict['img_np'] = img_np
            dict['label'] = all_cate.index(dir_name) + 1
            data_set.append(dict)

# 字典转DataFrame
df = pd.DataFrame(data_set)
pkl.dump(df, open('data/train_dataset.p', 'wb'))
open("data/all_cate.txt", encoding="utf-8", mode="w+").write("\n".join(all_cate))

print("存档数据成功~")
批量数据集标准化

        这里是读取序列化的图片信息,对所有图片统一像素 (一般配置电脑最好在 100px 以内,不然会很卡) 并标准归一化后,转换为 Tensor。然后判断图片通道数,如果是灰色图,可以复制张量三次以创建三个通道,最后通过 torch 的 DataLoader 在训练前完成数据集的加载。

# -*- coding: utf-8 -*-
import torch
from torchvision import transforms
import pickle as pkl
from torch.utils.data import Dataset

class DataSet(Dataset):

    def __init__(self, pkl_file):
        df = pkl.load(open(pkl_file, 'rb'))
        self.dataFrame = df

    def __len__(self):
        return len(self.dataFrame)

    def __getitem__(self, item):

        img_np = self.dataFrame.iloc[item, 0]
        label = self.dataFrame.iloc[item, 1]

        transform = transforms.Compose([
            transforms.Resize((100, 100)),  # 根据需要调整图像大小
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])    # 标准归一化, p1.均值  p2.方差
        ])
        image_tensor = transform(img_np)

        if image_tensor.shape[0] == 1:  
            image_tensor = image_tensor.repeat(3, 1, 1)  

        res = {
            'img_tensor': image_tensor,
            'label': torch.LongTensor([label-1])    # 需要实际的索引值
        }

        return res
神经网络模型

        这里创建的是卷积神经网络,接收 3 通道,第一层卷积层卷积核 3x3,输出 25 维张量,通过批标准化(BatchNorm2d)进行归一化处理,最后通过 ReLU 激活函数进行非线性变换。第一层池化使用 2x2 的最大池化操作对卷积后的特征图进行下采样。第二层也是卷积和对应的池化,最后是全连接层。将经过池化的特征图展平,然后通过一个有 1024 个神经元的全连接层,再通过 ReLU 激活函数进行非线性变换。之后是一个有 128 个神经元的全连接层,最后再通过 ReLU 激活函数进行非线性变换,输出 5 个神经元代表分类的概率分布。

# -*- coding: utf-8 -*-
import torch.nn as nn
import torch
import math
import torch.functional as F

class CNN(nn.Module):

    def __init__(self):
        super(CNN, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 25, kernel_size=3),
            nn.BatchNorm2d(25),
            nn.ReLU(inplace=True)
        )

        self.layer2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(25, 50, kernel_size=3),
            nn.BatchNorm2d(50),
            nn.ReLU(inplace=True)
        )

        self.layer4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fc = nn.Sequential(
            nn.Linear(50 * 23 * 23, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 5)
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = x.view(x.size(0), -1)

        x = self.fc(x)

        return x
开始训练
# -*- coding:utf-8 -*-
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from data_set import DataSet
from torch.autograd import Variable
from utils import *
import cnn
import torch.nn as nn
import numpy as np
import torch.optim as optim

# 定义超参数
batch_size = 1
learning_rate = 0.02
num_epoches = 1

# 加载图片tensor训练集
tain_dataset = DataSet("data/train_dataset.p")
train_loader = DataLoader(tain_dataset, batch_size=batch_size, shuffle=True)

model = cnn.CNN()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# 训练模型
train_loses = []
records = []
for i in range(num_epoches):
    for ii, data in enumerate(train_loader):
        img = data['img_tensor']
        label = data['label'].view(-1)

        optimizer.zero_grad()
        out = model(img)
        loss = criterion(out, label)
        train_loses.append(loss.data.item())
        loss.backward()
        optimizer.step()

        if ii % 50 == 0:
            print('epoch: {}, loop: {}, loss: {:.4}'.format(i, ii, np.mean(train_loses)))

        records.append([np.mean(train_loses)])

# 绘制模型的损失,准确率走势图
train_loss = [data[0] for data in records]
plt.plot(train_loss, label = 'Train Loss')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.legend()
plt.show()

# 模型评估(略)
# model.eval()

# 模型保存
torch.save(model, 'params/cnn_imgs_02.pkl')
模型检测

        训练完成保存参数到本地,下面就是将加载进的参数来测试其他图片的分类效果,同样的也是将指定图片和训练时一样的转换操作,最后将预测结果取出最大分布索引值,根据索引就可以匹配出分类名称了。另一个是工具函数,将 tensor 格式的图片在预测结果后显示在 pyplot 中。

# -*- coding:utf-8 -*-
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from data_set import DataSet
from utils import *
import torchvision
from PIL import Image
from torchvision import transforms
import cnn


def imshow(img):
    img = img / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

model = torch.load("params/cnn_imgs_02.pkl")

img_path= "imgs/05.jpg"
img_np = Image.open(img_path)
transform = transforms.Compose([
    transforms.Resize((100, 100)),  
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  
])
image_tensor = transform(img_np)

# 如果是灰度图片
if image_tensor.shape[0] == 1:  
    image_tensor = image_tensor.repeat(3, 1, 1)  

image_tensor = image_tensor.view(-1, 3, 100, 100)

predict = model(image_tensor)
indices = torch.max(predict, 1)[1].item()

all_cate = []
for line in open("data/all_cate.txt", encoding="utf-8", mode="r"):
    all_cate.append(line.strip())

cate_name = ""
try:
    cate_name = all_cate[indices]
except ValueError:
    cate_name = "未知"

print("识别结果是:", cate_name)
# imshow(torchvision.utils.make_grid(image_tensor))
# 原图显示
img_np.show()
exit()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1161682.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

苹果手机黑屏了怎么恢复正常?这些修复方法记得收藏好!

苹果手机黑屏是一种常见的手机故障,很多人在遇到手机突然黑屏的情况时感到束手无策。手机黑屏会干扰用户的正常使用,带来不便,并给用户带来不好的体验。苹果手机黑屏了怎么恢复正常?本文将为大家详细介绍修复苹果手机黑屏的方法。…

3.网络之UDP

UDP协议 文章目录 UDP协议1. UDP概述2. UDP报文格式3. UDP传输限制4. UDP校验和4.1 CRC 循环冗余校验算法4.2 md5 校验算法 1. UDP概述 UDP(UserDatagramProtocol)是一个简单的面向消息的传输层协议,尽管UDP提供标头和有效负载的完整性验证&a…

Linux指令【下】

目录 时间 date 时间戳 cal 查找 find which whereis grep uniq 打包压缩 zip/unzip tar uname 其他热键 关机 系统互传 linux和Windows Linux和Linux 时间 date 用法:date[option] [format] 选项功能%Y年%d日%m月%H时%M分%s秒%X%H:%M%S%F%Y-%…

「Java开发指南」如何用MyEclipse搭建Spring MVC应用程序?(一)

本教程将指导开发者如何生成一个可运行的Spring MVC客户应用程序,该应用程序实现域模型的CRUD应用程序模式。在本教程中,您将学习如何: 从数据库表的Scaffold到现有项目部署搭建的应用程序 使用Spring MVC搭建需要MyEclipse Spring或Bling授…

【MySQL数据库】 一

本文主要介绍了关系型数据库和非关系数据库的区别,以及主流的关系型数据库mysql的安装 , 以及mysql数据库客户端-服务器的结构. 一.数据库的分类 我们可以简单的把数据库看成是一类软件 数据库分成两大类 1.关系型数据库 通常以表格的方式来组织 2.非关系型数据库 通常以键值…

电脑关机很慢?这几个优化技巧请收好!

当我们使用电脑时,一个令人不快的问题是,关机变得异常缓慢。电脑在关机时可能需要很长时间,甚至让人感到沮丧。这不仅是时间浪费,还可能表明系统存在问题。在本文中,我们将介绍四种解决电脑关机很慢的方法,…

怎么向国外客户催单?这样做既有效又不让客户反感

外贸业务员接单的过程其实是一场利益博弈的过程,而且外贸销售永远靠结果说话,所以无论你的客户如何承诺,甚至有时候都已经签订了定单合同做了PI,但客户钱没到账公司,一切就有可能归零。 01 心态一定要稳住 对于每个有…

如何在业务代码中优雅地使用责任链模式

通过使用责任链模式,我们可以更加灵活和优雅地处理请求,降低代码之间的耦合度,提高代码的可维护性和可扩展性。在一些具有复杂业务逻辑或需要动态处理请求的场景下,使用责任链模式将是一个很好的选择。本文将通过一个具体的示例来…

【错误解决方案】ModuleNotFoundError: No module named ‘tensorboardX‘

1. 错误提示 在python程序中,尝试导入一个名为tensorboardX的模块,但Python提示找不到这个模块。 错误提示:ModuleNotFoundError: No module named ‘tensorboardX‘ 2. 解决方案 在python出现中,遇到这个问题是Python无法找到…

无需编程技术,快速搭建个人网站

如果你想拥有一个属于自己的个人网站,但又没有任何编程经验,别担心,我们今天将为你介绍一个简单的方法,让你轻松搭建网站,无需任何编程知识。让我们一起来看看吧! 在乔拓云建站工具中,自带了许多…

TypeScript之装饰器

一、是什么 装饰器是一种特殊类型的声明,它能够被附加到类声明,方法, 访问符,属性或参数上 是一种在不改变原类和使用继承的情况下,动态地扩展对象功能 同样的,本质也不是什么高大上的结构,就…

图纸管理制度《八》设计图纸管理制度

第一章 总则 第1条 目的。为做好设计图纸的管理工作,使其收发及时、手续齐全、废图绝迹、不遗失、无差错,特制定本办法。 第2条 适用范围。本办法适用于企业所有工程项目的图纸管理工作。 第3条 相关部门及人员职责 (1) 工程技术部负责图纸管理的监督…

双十一百亿美元补贴,AWS阿里云腾讯云华为云国际版钜惠

双十一来袭!阿里云/腾讯云/华为云国际站该怎么玩?九河云(双十一特大促销,低至5.18折 (9he.com))这次双十一活动汇聚了一系列前所未有的优惠,不仅能享受服务器和CDN的超值折扣,还有机会赢取华为M…

智慧工地建造平台源码、智慧化工地云平台源码

概述:智慧工地管理平台充分运用数字化技术,聚焦施工现场岗位一线,依托物联网、互联网、AI等技术,围绕施工现场管理的人、机、料、法、环五大维度,以及施工过程管理的进度、质量、安全三大体系为基础应用,实…

Spring Cloud Alibaba中Nacos的安装(Windows平台)以及服务的发现

Spring Cloud Alibaba中Nacos的安装(Windows平台)以及服务的发现 下载安装Nacos解压启动验证是否启动搭建一个简单的Spring Cloud Alibaba项目Spring Cloud Alibaba 以及 Nacos的引入如何选择对应的版本 服务的注册Nacos相关组件的说明 下载安装Nacos G…

Python如何解析json对象?

目录 一、JSON简介 二、Python的json模块 1. 加载JSON数据 2. 生成JSON数据 三、处理复杂的JSON数据 四、自定义JSON解析器 五、注意事项和最佳实践 六、总结 JSON(JavaScript Object Notation)作为一种轻量级的数据交换格式,在网络通…

https原理

首先说一下几个概念:对称加密、非对称加密 对称加密: 客户端和服务端使用同一个秘钥,分两种情况: 1、所有的客户端和服务端使用同一个秘钥,这个秘钥被泄漏后数据不再安全 2、每个客户端生成一个秘钥&…

嵌入式每日500(4)231104 (Flash类型定义、Flash常量定义、Flash函数)

文章目录 1.Flash类型定义(两个结构体)2.Flash常量定义(3种)3.Flash函数(31个,FLASH分为两个区,一个是普通的存储空间,一个是选项字节OB,函数名里带OB的就是对选项字节空…

一文速通Sentinel熔断及降级规则

目录 基本介绍 熔断模式 状态机的三个状态 熔断降级规则 断路器熔断策略 慢调用 异常比例 异常数 基本介绍 熔断模式 主要是参考电路熔断,如果一条线路电压过高,保险丝会熔断,防止火灾。放到我们的系统中,如果某个目标…

Azure 机器学习 - 无代码自动机器学习的预测需求

了解如何在 Azure 机器学习工作室中使用自动化机器学习在不编写任何代码行的情况下创建时序预测模型。 此模型将预测自行车共享服务的租赁需求。 关注TechLead,分享AI全维度知识。作者拥有10年互联网服务架构、AI产品研发经验、团队管理经验,同济本复旦硕…