PyTorch 实现食物图像分类实战:从数据处理到模型训练

news2025/4/26 7:01:08

一、简介

在计算机视觉领域,图像分类是一项基础且重要的任务,广泛应用于智能安防、医疗诊断、电商推荐等场景。本文将以食物图像分类为例,基于 PyTorch 框架,详细介绍从数据准备、模型构建到训练测试的全流程,帮助读者深入理解深度学习图像分类的实践过程。

 二、原理

 三、代码示例

1、数据文件路径准备

在实际项目中,原始图像数据通常按类别存储在不同文件夹下。代码中的train_test_file函数用于遍历数据文件夹,将图像文件路径及其对应的类别标签写入文本文件,方便后续数据加载:

import os
def train_test_file(root, dir):
    file_txt = open(dir+'.txt','w')
    path = os.path.join(root, dir)
    for roots, directories, files in os.walk(path):#os.walk(path)代表开始去遍历path路径下的文件
        if len(directories) != 0:
            dirs = directories
        else:
            now_dir = roots.split('\\')
            for file in files:
                path_1 = os.path.join(roots,file)
                print(path_1)
                file_txt.write(path_1+' '+str(dirs.index(now_dir[-1]))+'\n')
    file_txt.close()
root = r'.\食物分类\food_dataset2'
train_dir = 'train'
test_dir = 'test'
train_test_file(root,train_dir)
train_test_file(root,test_dir)

该函数通过os.walk递归遍历文件夹,将每个图像文件的绝对路径与对应的类别索引(通过文件夹名称顺序确定)写入.txt文件,格式为图像路径 标签。

2、自定义数据集类

import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset,DataLoader #用于处理数据集
from torchvision import transforms
data_transforms = {#字典
    'train':
        transforms.Compose([#对图片做预处理的,组合
        transforms.Resize([256,256]),#数据进行改变大小
        transforms.ToTensor(),#数据转换为tensor,默认把通道维度放在前面
    ]),
    'valid':
        transforms.Compose([
        transforms.Resize([256, 256]),
        transforms.ToTensor(),
    ]),
}#数组增强

class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.file_path = file_path#为了将外部空间的路径传递给共享空间,以便于后期可以使用
        self.imgs = []
        self.labels = []
        self.transform = transform
        with open(self.file_path) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path, label in samples:
                self.imgs.append(img_path)#图像的路径
                self.labels.append(label)#标签,还不是tensor
    def __len__(self):
        return len(self.imgs)
    def __getitem__(self, idx):
        image = Image.open(self.imgs[idx])#读取到图片数据,还不是tensor,BGR
        if self.transform:#将pil图像数据转换为tensor
            image = self.transform(image)

        label = self.labels[idx]#label还不是tensor
        label = torch.from_numpy(np.array(label, dtype = np.int64))#label也转换为tensor
        return image, label

 food_dataset类实现了__len__和__getitem__两个关键方法:

__len__返回数据集的样本总数;

__getitem__根据索引读取图像文件,应用数据变换(如调整大小、转换为张量),并将标签转换为torch.Tensor格式后返回。

3、数据加载器配置 

#training_data包含了本次训练需要的全部数据集
training_data = food_dataset(file_path = './train.txt',transform = data_transforms['train'])
test_data = food_dataset(file_path = './test.txt',transform = data_transforms['valid'])
#training_data需要具备索引的功能,还要确保数据是tensor
train_dataloader = DataLoader(training_data, batch_size=64,shuffle = True)#64个图片为一个包,shuffle = True用于将数据进行打乱
test_dataloader = DataLoader(test_data, batch_size = 64,shuffle = True)

 DataLoader的batch_size参数指定每个批次包含的样本数量,shuffle=True表示在每个 epoch 训练前打乱数据顺序,有助于提高模型的泛化能力

4、搭建卷积神经网路模型

'''定义神经网络'''
from torch import nn   #导入神经网络模块

class CNN(nn.Module):
    def __init__(self):   #python基础关于类,self类自已本身
        super(CNN,self).__init__()   #继承的父类初始化
        self.conv1=nn.Sequential(    #将多个层组合成一起。创建了一个容器,将多个网络合在一起
            nn.Conv2d(       #2d一般用于图像,3d用于视频数据(多一个时间维度),1d一般用于结构化的序列数据
                in_channels=3,   #、图像通道个数,1表示灰度图(确定了卷积核 组中的个数)
                out_channels=16,   # 要得到几多少个特征图,卷积核的个数
                kernel_size=5,     # 卷积核大小,5*5
                stride=1,          # 步长
                padding=2,        #一般希望卷积核处理后的结果大小与处理前的数据大小相同,效果会比较好。那padding改如何
            ),                     # 输出的特征图为(16,28,28)
            nn.ReLU(),            # relu层,不会改变特征图的大小
            nn.MaxPool2d(kernel_size=2),        #进行池化操作(2x2 区域),输出结果为:(16,128,128)
        )
        self.conv2=nn.Sequential(   #输入
            nn.Conv2d(16,32,5,1,2),   # 输出(32 128 128)
            nn.ReLU(),
            nn.Conv2d(32,32,5,1,2),  # 输出(32 128 128)
            nn.ReLU(),
            nn.MaxPool2d(2),     #输出(32,64,64)
        )
        self.conv3=nn.Sequential(      #输入(32 64 64)
            nn.Conv2d(32,128,5,1,2),   #(128 64 64)
            nn.ReLU(),
        )
        self.out=nn.Linear(128*64*64,20)    #全连接层得到的结果

    def forward(self,x):
        x=self.conv1(x)
        x=self.conv2(x)
        x=self.conv3(x)
        x=x.view(x.size(0),-1)
        output=self.out(x)
        return output

model = CNN().to(device)
print(model)

5、训练与测试函数实现

def train(dataloader,model,loss_fn,optimizer):
    model.train()   #告诉模型,我要开始训练,模型中w进行随机化操作,已经更新w。在训练过程中,w会被修改的
#pytorch提供2种方式来切换训练和测试的模式,分别是:model.train()和 model.eval()。
#一般用法是:在训练开始之前写上model.trian(),在测试时写上 model.eval()
    batch_size_num=1
    for X,y in dataloader:       #其中batch为每一个数据的编号,X是打包好的每一个数据包
        X,y=X.to(device),y.to(device)    #把训练数据集和标签传入cpu或GPU
        pred=model.forward(X)    #.forward可以被省略,父类中已经对次功能进行了设置。自动初始化w权值
        loss=loss_fn(pred,y)     #通过交叉熵损失函数计算损失值loss
        # Backpropagation 进来一个batch的数据,计算一次梯度,更新一次网络
        optimizer.zero_grad()    #梯度值清零
        loss.backward()          #反向传播计算得到每个参数的梯度值w
        optimizer.step()         #根据梯度更新网络w参数

        loss_value=loss.item()   #从tensor数据中提取数据出来,tensor获取损失值
        if batch_size_num %1 ==0:
            print(f'loss:{loss_value:>7f} [number:{batch_size_num}]')
        batch_size_num+=1

def test(dataloader,model,loss_fn):
    size=len(dataloader.dataset)
    num_batches=len(dataloader)  #打包的数量
    model.eval()  #测试,w就不能再更新。
    test_loss,correct=0,0
    with torch.no_grad():    #一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候。
        for X,y in dataloader:
            X,y=X.to(device),y.to(device)
            pred=model.forward(X)
            test_loss+=loss_fn(pred,y).item()   #test_loss是会自动累加每一个批次的损失值
            correct+=(pred.argmax(1)==y).type(torch.float).sum().item()
            a=(pred.argmax(1)==y)   #dim=1表示每一行中的最大值对应的索引号,dim=0表示每一列中的最大值
            b=(pred.argmax(1)==y).type(torch.float)
    test_loss /=num_batches
    correct /= size
    print(f'Test result: \n Accuracy: {(100*correct)}%, Avg loss: {test_loss}')

 6、模型训练与评估

loss_fn=nn.CrossEntropyLoss()   #创建交叉熵损失函数对象,因为食物的类别是20
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)   #创建一个优化器,SGD为随机梯度下降算法
# #params:要训练的参数,一般我们传入的都是model.parameters()#
# lr:learning_rate学习率,也就是步长

#loss表示模型训练后的输出结果与,样本标签的差距。如果差距越小,就表示模型训练越好,越逼近干真实的模型。

# train(train_dataloader,model,loss_fn,optimizer)
# test(test_dataloader,model,loss_fn)

epochs=1
for t in range(epochs):
    print(f"Epoch {t+1}\n---------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader,model,loss_fn)

 7、运行结果

四、总结与优化方向

本文完整展示了基于PyTorch的食物图像分类项目流程,涵盖数据处理、模型构建和训练测试等核心环节。然而,当前模型仍有优化空间:

  1. 数据增强:增加更多数据增强策略(如随机裁剪、颜色抖动)以扩充数据集;
  2. 模型优化:尝试更复杂的预训练模型(如ResNet、VGG)或调整超参数(学习率、batch size);
  3. 正则化:添加Dropout或L2正则化防止过拟合。

通过不断改进和实践,图像分类模型的准确率和泛化能力将得到进一步提升。希望本文能为读者在深度学习图像分类领域的学习和实践提供有益参考。

 

 

 

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2343078.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Qt —— 在Linux下试用QWebEngingView出现的Js错误问题解决(附上四种解决办法)

错误提示:js: A parser-blocking, cross site (i.e. different eTLD+1) script, https:xxxx, is invoked via document.write. The network request for this script MAY be blocked by the browser in this or a future page load due to poor network connectivity. If bloc…

命名空间(C++)

命名空间主要用于大型项目中。 局部命名在该局部会覆盖全局命名。C语言中唯一一种在局部调用全局相同命名的全局变量的方式:指针在C中可以用作用域运算符来访问全局变量,作用域运算符的前面可以是作用域也可以是类。 命名空间实际上是对全局作用域的再次…

LabVIEW圆锥滚子视觉检测系统

基于LabVIEW平台的视觉检测系统提高圆锥滚子内组件的生产质量和效率。通过集成高分辨率摄像头和先进的图像处理算法,系统能够自动识别和分类产品缺陷,从而减少人工检查需求,提高检测的准确性和速度。 ​​ ​ 项目背景 随着制造业对产品质…

OpenAI 推出「轻量级」Deep Research,免费用户同享

刚刚,OpenAI 正式上线了面向所有用户的「轻量级」Deep Research 版本,意味着即便没有付费订阅,也能体验这一强大工具的核心功能。 核心差异:o4-mini vs. o3 模型迭代 传统的深度研究功能基于更大规模的 o3 模型。轻量级版本则改以…

罗伯·派克:Go语言创始者的极客人生

名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 罗伯派克:Go语言创始者的极客人生 一、传奇程序员的成长历程 1. 早年经历…

小白工具视频转MPG, 功能丰富齐全,无需下载软件,在线使用,超实用

在视频格式转换需求日益多样的今天,小白工具网的在线视频转 MPG 功能https://www.xiaobaitool.net/videos/convert-to-mpg/ )脱颖而出,凭借其出色特性,成为众多用户处理视频格式转换的优质选择。 从格式兼容性来看,它支…

day32 学习笔记

文章目录 前言一、霍夫变换二、标准霍夫变换三、统计概率霍夫变换四、霍夫圆变换 前言 通过今天的学习,我掌握了霍夫变换的基本原本原理及其在OpenCV中的应用方法 一、霍夫变换 霍夫变换是图像处理中的常用技术,主要用于检测图像中的直线,圆…

CentOS 7上Memcached的安装、配置及高可用架构搭建

Memcached是一款高性能的分布式内存缓存系统,常用于加速动态Web应用的响应。本文将在CentOS 7上详细介绍Memcached的安装、配置,以及如何实现Memcached的高可用架构。 (1)、搭建memcached 主主复制架构 Memcached 的复制功能支持…

如何让 HTML 文件嵌入另一个 HTML 文件:详解与实践

目录 一、为什么需要在HTML中嵌入其他HTML文件? 二、常用的方法概览 三、利用 1. 基本原理 2. 使用场景 3. 优缺点 4. 实践示例 5. 适用建议 四、利用JavaScript动态加载内容 1. 原理简介 2. 实现步骤 示例代码 3. 优缺点分析 4. 应用场景 5. 实践建…

人工智能与机器学习:Python从零实现逻辑回归模型

🧠 向所有学习者致敬! “学习不是装满一桶水,而是点燃一把火。” —— 叶芝 我的博客主页: https://lizheng.blog.csdn.net 🌐 欢迎点击加入AI人工智能社区! 🚀 让我们一起努力,共创…

windows服务器及网络:搭建FTP服务器

前言:(各位大佬们,昨天太忙了,整得没有发布昨天那该写的那一篇,属实有点可惜的说QAQ,不过问题已经解决,我又回来啦) 今天我要介绍的是在Windows中关于搭建FTP服务器的流程与方法 注…

欧拉计划 Project Euler56(幂的数字和)题解

欧拉计划 Project Euler 56 题解 题干思路code 题干 思路 直接暴力枚举即可&#xff0c;用c要模拟大数的乘法&#xff0c;否则会溢出 code // 972 #include <bits/stdc.h>using namespace std;using ll long long;string mul(const string &num1, int num2) {int…

C++初窥门径

const关键字 一、const关键字 修饰成员变量 常成员变量&#xff1a;必须通过构造函数的初始化列表进行初始化&#xff0c;且初始化后不可修改。 示例&#xff1a; class Student { private: const int age; // 常成员变量 public: Student(string name, int age) : age(ag…

AlarmClock4.8.4(官方版)桌面时钟工具软件下载安装教程

1.软件名称&#xff1a;AlarmClock 2.软件版本&#xff1a;4.8.4 3.软件大小&#xff1a;187 MB 4.安装环境&#xff1a;win7/win10/win11(64位) 5.下载地址&#xff1a; https://www.kdocs.cn/l/cdZMwizD2ZL1?RL1MvMTM%3D 提示&#xff1a;先转存后下载&#xff0c;防止资…

白鲸开源WhaleStudio与崖山数据库管理系统YashanDB完成产品兼容互认证

近日&#xff0c;北京白鲸开源科技有限公司与深圳计算科学研究院联合宣布&#xff0c;双方已完成产品兼容互认证。此次认证涉及深圳计算科学研究院自主研发的崖山数据库管理系统YashanDB V23和北京白鲸开源科技有限公司的核心产品WhaleStudio V2.6。经过严格的测试与验证&#…

【金仓数据库征文】- 金融HTAP实战:KingbaseES实时风控与毫秒级分析一体化架构

文章目录 引言&#xff1a;金融数字化转型的HTAP引擎革命一、HTAP架构设计与资源隔离策略1.1 混合负载物理隔离架构1.1.1 行列存储分区策略1.1.2 四级资源隔离机制 二、实时流处理与增量同步优化2.1 分钟级新鲜度保障2.1.1 WAL日志增量同步2.1.2 流计算优化 2.2 物化视图实时刷…

Windows与CasaOS跨平台文件同步:SyncThing本地部署与同步配置流程

文章目录 前言1. 添加镜像源2. 应用安装测试3. 安装syncthing3.1 更新应用中心3.2 SyncThing安装与配置3.3 Syncthing使用演示 4. 安装内网穿透工具5. 配置公网地址6. 配置固定公网地址 推荐 ​ 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽…

59、微服务保姆教程(二)Nacos--- 微服务 注册中心 + 配置中心

Nacos— 微服务 注册中心 + 配置中心 一.什么是Nacos? Nacos是阿里的一个开源产品,是针对微服务架构中的服务发现、配置管理、服务治理的综合型解决方案。 Nacos核心定位是“一个更易于帮助构建云原生应用的动态服务发现、配置和服务管理平台”,也就是我们的注册中心和配…

第一部分:git基本操作

目录 1、git初识 1.1、存在的问题 1.2、版本控制器 1.3、git安装 1.3.1、CentOS平台 1.3.2、ubuntu平台 2、git基本操作 2.1、创建仓库 2.2、配置git 3、工作区、暂存区、版本库 4、基本操作 4.1、场景一 4.2、场景二 4.3、修改文件 5、版本回退 6、撤销修改 …

《一文读懂Transformers库:开启自然语言处理新世界的大门》

《一文读懂Transformers库:开启自然语言处理新世界的大门》 GitHub - huggingface/transformers: 🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX. HF-Mirror Hello! Transformers快速入门 pip install transformers -i https:/…