# 基于PyTorch的食品图像分类系统:从训练到部署全流程指南

news2025/4/23 5:42:00

基于PyTorch的食品图像分类系统:从训练到部署全流程指南

本文将详细介绍如何使用PyTorch框架构建一个完整的食品图像分类系统,涵盖数据预处理、模型构建、训练优化以及模型保存与加载的全过程。

1. 系统概述

本系统实现了一个基于卷积神经网络(CNN)的食品图像分类器,主要特点包括:

  • 支持20种不同食品的分类
  • 使用数据增强提高模型泛化能力
  • 实现了完整的训练-验证-测试流程
  • 提供模型保存与加载功能

2. 数据准备与预处理

2.1 数据增强策略

在这里插入图片描述

我们为训练集和验证集分别设计了不同的数据增强策略:

data_transforms = {'train':  # 训练集  也可以使用PIL库  smote 训练集
    transforms.Compose([  # transforms.Compose用于将多个图像预处理操作整合在一起
        transforms.Resize([300, 300]),  # 使图像变换大小
        transforms.RandomRotation(45),  # 随机旋转,-42到45度之间随机选
        transforms.CenterCrop(256),  # 从中心开始裁剪[256.256]
        transforms.RandomHorizontalFlip(p=0.5),  # 随机水平旋转,随机概率为0.5
        transforms.RandomVerticalFlip(p=0.5),  # 随机垂直旋转,随机概率0.5
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),  # 随机改变图像参数,参数分别表示 亮度、对比度、饱和度、色温
        transforms.RandomGrayscale(p=0.1),  # 概率转换成灰度率,3通道就是R=G=B
        transforms.ToTensor(),  # 将PIL图像或NumPy ndarray转换为tensor类型,并将像素值的范围从[0, 255]缩放到[0.0, 1.0],默认把通道维度放在前面
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 给定均值和标准差对图像进行标准化,前者为均值,后者为标准差,三个值表示三通道图像

    ]),
    'valid':  # 验证集
        transforms.Compose([  # 整合图像处理的操作
            transforms.Resize([256, 256]),  # 缩放图像尺寸
            transforms.ToTensor(),  # 转换为torch类型
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 标准化
        ])
}

关键点说明

  • 训练集使用了丰富的数据增强来防止过拟合
  • 验证集只进行必要的尺寸调整和归一化
  • 使用ImageNet的均值和标准差进行归一化

2.2 自定义数据集类

我们创建了food_dataset类来管理数据:

class food_dataset(Dataset):  # food_dataset是自己创建的类名称,继承Dataset类
    def __init__(self, file_path, transform=None):  # 类的初始化,解析数据文件txt,file_path表示文件路径,transform可选的图像转换操作
        self.file_path = file_path  # 将文件地址传入self空间
        self.imgs = []
        self.labels = []
        self.transform = transform  # 将数据增强操作传入self空间
        with open(self.file_path) as f:  # 打开存放图片地址及其类别的文本文件train.txt,
            samples = [x.strip().split(' ') for x in f.readlines()]  # 遍历文件里的每一条数据,经过处理后存入sample列表,元祖的形式存放
            for img_path, label in samples:  # 遍历列表中的每个元组的每个元素
                self.imgs.append(img_path)  # 将图像的路径存入img列表
                self.labels.append(label)  # 将图片类别标签存入label列表

    # 初始化:把图片目录加载到self.

    def __len__(self):  # 类实例化对象后,可以使用len函数测量对象的个数
        return len(self.imgs)  # 返回数据集中样本的总数

    def __getitem__(self, idx):  # 关键,可通过索引idx的形式获取每一个图片数据及标签
        image = Image.open(self.imgs[idx])  # 使用PIL库中的用法Image打开并识别图像,还不是tensor
        if self.transform:  # 判断是否有图像转换操作,上述定义默认为None,有则将pil图像数据转换为tensor类型
            image = self.transform(image)  # 图像处理为256*256,转换为tenor

        label = self.labels[idx]  # label还不是tensor
        label = torch.from_numpy(
            np.array(label, dtype=np.int64))  # 首先指定标签类型为int型,然后将其转换为numpy数组类型,然后再使用torch.from_numpy转换为torch类型
        return image, label  # 返回处理完的图片和标签


关键方法

  • __init__: 从文本文件加载图像路径和标签
  • __len__: 返回数据集大小
  • __getitem__: 按索引返回图像和标签

3. 模型架构设计

我们构建了一个三层的CNN模型:

class CNN(nn.Module):
    def __init__(self):  # 翰入大小 (3,256,256)
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(  # 将多个层组合成一起。
            nn.Conv2d(  # 2d一般用于图像,3d用于视频数据(多一个时间维度),1d一般用于结构化的序
                in_channels=3,  # 图像通道个数,1表示灰度图(确定了卷积核 组中的个数)
                out_channels=16,  # 要得到几多少个特征图,卷积核的个数.
                kernel_size=5,  # 卷积核大小,5*5
                stride=1,  # 步长
                padding=2,  # 一般希望卷积核处理后的结果大小与处理前的数据大小相同,效果会比较好。那p
            ),  # 输出的特征图为 (16,256,256)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # 进行池化操作(2x2 区域),输出结果为:(16,128,128)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),  # 输出(32,128,128)
            nn.ReLU(),
            nn.MaxPool2d(2)  # 输出
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 128, 5, 1, 2),
            nn.ReLU(),
        )
        self.out = nn.Linear(128 * 64 * 64, 20)  # 全连接

    def forward(self, x):  # 前向传播
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)  # 输出(64,128,64,64)
        x = x.view(x.size(0), -1)
        output = self.out(x)
        return output  # 返回输出结果

架构特点

  1. 使用nn.Sequential组织网络层
  2. 每层包含卷积、ReLU激活和池化
  3. 最后一层全连接输出20个类别的概率

4. 模型训练与验证

4.1 训练流程

def train(dataloader, model, loss_fn, optimizer):  # 传入参数 打包的数据,卷积模型,损失函数,优化器
    model.train()  # 表示模型开始训练
    batch_size_num = 1
    for x, y in dataloader:  # 遍历打包的图片及其对应的标签,其中batch为每一个数据的编号
        x, y = x.to(device), y.to(device)  # 把训练数据集和标签传入cpu或GPU
        pred = model.forward(x)  # 自动初始化 W权值
        loss = loss_fn(pred, y)  # 传入模型训练结果的预测值和真实值,通过交叉熵损失函数计算损失值L0

        optimizer.zero_grad()  # 梯度值清零
        loss.backward()  # 反向传播计算得到每个参数的梯度
        optimizer.step()  # 根据梯度更新网络参数

        loss = loss.item()  # 获取损失值
        if batch_size_num % 100 == 0:
            print(f"loss: {loss:>7f}[number:{batch_size_num}]")  # 打印损失值,右对齐,长度为7
        batch_size_num += 1  # 右下方传入的参数,表示训练轮数

4.2 验证流程

def test(dataloader, model, loss_fn):  # 定义一个test函数,用于测试模型性能
    global best_acc  # 定义一个全局变量
    size = len(dataloader.dataset)  # 返回打包的图片总数
    num_batches = len(dataloader)  # 返回打包的包的个数
    model.eval()  # 表示模型进入测试模式
    test_loss, correct = 0, 0  # 初始化两个值,一个用来存放总体损失值,一个存放预测准确的个数
    with torch.no_grad():  # 一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()时可以减少
        for x, y in dataloader:  # 遍历数据加载器中测试集图片的图片及其标签
            x, y = x.to(device), y.to(device)  # 传入GPU
            pred = model.forward(x)  # 前向传播,返回预测结果
            test_loss += loss_fn(pred, y).item()  # 计算所有的损失值的和,item表示将tensor类型值转化为python标量
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()  # 判断预测的值是等于真实值,返回布尔值,将其转换为0和1,然后求和
            # a = (pred.argmax(1)== y)  dim=1表示每一行中的最大值对应的索引号,dim=日表示每 b=(pred.argmax(1)==y).type(torch.float)

        test_loss /= num_batches  # 总体损失值除以数据条数得到平均损失值
        correct /= size  # 求准确率
        print(f"Test result:in Accuracy: {(100 * correct)}%, Avg loss: {test_loss}")  # 表示准确率机器对应的损失值
        # acc_s.append(correct)
        # loss_s.append(test_loss)


### 4.3 训练配置

```python
# 初始化
model = CNN().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 数据加载
#training_data包含了本次需要训练的全部数据集
training_data = food_dataset(file_path=r'D:\Users\妄生\PycharmProjects\人工智能\深度学习\train.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path=r'D:\Users\妄生\PycharmProjects\人工智能\深度学习\test.txt', transform=data_transforms['valid'])

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=True)

# 训练循环
epochs = 150  # 设置模型训练的轮数,不停更新模型参数,找到最优值
acc_s = []  # 初始化了两个空列表,用于存储模型在每个epoch结束时的准确率和损失值
loss_s = []
for t in range(epochs):  # 遍历轮数
    print(f"Epoch {t + 1}\n---------------------------")  # 表示轮数展示
    train(train_dataloader, model, loss_fn, optimizer)  # 调用函数train传入训练集数据加载器、初始化的模型、损失函数、优化器
    test(test_dataloader, model, loss_fn) 

运行结果

在这里插入图片描述

5. 模型保存与加载

5.1 保存模型

我们提供了两种保存方式:

# 方法1:仅保存模型参数(推荐)
torch.save(model.state_dict(), 'best.pth')

# 方法2:保存整个模型
torch.save(model, 'best.pt')

5.2 加载模型

对应两种加载方式:

# 方法1:加载参数
model = CNN().to(device)
model.load_state_dict(torch.load('best.pth'))

# 方法2:加载完整模型
model = torch.load('best.pt')

6. 模型测试与结果分析

我们实现了详细的测试函数:

def test_true(dataloader, model):
    correct = 0  # 正确预测的数量
    total = 0  # 总样本数量
    with torch.no_grad():  # 上下文管理器,关闭梯度运算
        for x, y in dataloader:  # 遍历打包好的图片及其标签
            x, y = x.to(device), y.to(device)  # 将其传入GPU
            pred = model.forward(x)  # 前向传播
            _, predicted = torch.max(pred, 1)  # 获取预测值的类别索引
            total += y.size(0)  # 累加总样本数量
            correct += (predicted == y).sum().item()  # 累加正确预测的数量

            result.append(predicted.item())  # 将预测值的结果转换成Python变量然后增加到列表
            labels.append(y.item())  # 同时将真实值的标签转变成Python标量然后存入labels列表

    accuracy = correct / total  # 计算准确率
    print(f'准确率: {accuracy:.4f}')  # 打印准确率

# 调用测试函数
test_true(test_dataloader, model)  # 导入数据和模型
print('预测值:\t', result)
print('真实值:\t', labels)

运行结果

在这里插入图片描述

7. 总结

本文详细介绍了基于PyTorch的食品图像分类系统的完整实现流程,从数据准备到模型部署。该系统具有以下优势:

  1. 高效的数据处理:完善的数据增强和加载机制
  2. 可靠的模型架构:经过优化的CNN结构
  3. 完整的训练流程:包含训练、验证和测试
  4. 灵活的部署方案:提供多种模型保存方式

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2340533.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

v-html 显示富文本内容

返回数据格式&#xff1a; 只有图片名称 显示不出完整路径 解决方法&#xff1a;在接收数据后手动给img格式的拼接vite.config中的服务器地址 页面&#xff1a; <el-button click"">获取信息<el-button><!-- 弹出层 --> <el-dialog v-model&…

【数学建模】孤立森林算法:异常检测的高效利器

孤立森林算法&#xff1a;异常检测的高效利器 文章目录 孤立森林算法&#xff1a;异常检测的高效利器1 引言2 孤立森林算法原理2.1 核心思想2.2 算法流程步骤一&#xff1a;构建孤立树(iTree)步骤二&#xff1a;构建孤立森林(iForest)步骤三&#xff1a;计算异常分数 3 代码实现…

<项目代码>YOLO小船识别<目标检测>

项目代码下载链接 YOLOv8是一种单阶段&#xff08;one-stage&#xff09;检测算法&#xff0c;它将目标检测问题转化为一个回归问题&#xff0c;能够在一次前向传播过程中同时完成目标的分类和定位任务。相较于两阶段检测算法&#xff08;如Faster R-CNN&#xff09;&#xff0…

Crawl4AI:打破数据孤岛,开启大语言模型的实时智能新时代

当大语言模型遇见数据饥渴症 在人工智能的竞技场上&#xff0c;大语言模型&#xff08;LLMs&#xff09;正以惊人的速度进化&#xff0c;但其认知能力的跃升始终面临一个根本性挑战——如何持续获取新鲜、结构化、高相关性的数据。传统数据供给方式如同输血式营养支持&#xff…

【Spring Boot】MyBatis多表查询的操作:注解和XML实现SQL语句

1.准备工作 1.1创建数据库 &#xff08;1&#xff09;创建数据库&#xff1a; CREATE DATABASE mybatis_test DEFAULT CHARACTER SET utf8mb4;&#xff08;2&#xff09;使用数据库 -- 使⽤数据数据 USE mybatis_test;1.2 创建用户表和实体类 创建用户表 -- 创建表[⽤⼾表…

[Android]豆包爱学v4.5.0小学到研究生 题目Ai解析

拍照解析答案 【应用名称】豆包爱学 【应用版本】4.5.0 【软件大小】95mb 【适用平台】安卓 【应用简介】豆包爱学&#xff0c;一般又称河马爱学教育平台app,河马爱学。 关于学习&#xff0c;你可能也需要一个“豆包爱学”这样的AI伙伴&#xff0c;它将为你提供全方位的学习帮助…

Qt开发:软件崩溃时,如何生成dump文件

文章目录 一、程序崩溃时如何自动生成 Dump 文件二、支持多线程中的异常捕获三、在 DLL 中使用 Dump 捕获四、封装成可复用类五、MiniDumpWriteDump函数详解 一、程序崩溃时如何自动生成 Dump 文件 步骤一&#xff1a;包含必要的头文件 #include <Windows.h> #include …

普罗米修斯Prometheus监控安装(mac)

普罗米修斯是后端数据监控平台&#xff0c;通过Node_exporter/mysql_exporter等收集数据&#xff0c;Grafana将数据用图形的方式展示出来 官网各平台下载 Prometheus安装&#xff08;mac&#xff09; &#xff08;1&#xff09;通过brew安装 brew install prometheus &…

Python SQL 工具包:SQLAlchemy介绍

SQLAlchemy 是一个功能强大且灵活的 Python SQL 工具包和对象关系映射&#xff08;ORM&#xff09;库。它被广泛用于与关系型数据库进行交互&#xff0c;提供了从低级 SQL 表达式到高级 ORM 的完整工具链。SQLAlchemy 的设计目标是让开发者能够以 Pythonic 的方式操作数据库&am…

Shader属性讲解+Cg语言讲解

CPU调用GPU传递数据 修改Render组件的material属性 在脚本中更改游戏物体材质颜色代码示例&#xff1a; using System.Collections; using System.Collections.Generic; using UnityEngine;public class TestFixedColor : MonoBehaviour {void Start(){//创建预制体GameObjec…

基于LightGBM-TPE算法对交通事故严重程度的分析与可视化

基于LightGBM-TPE算法对交通事故严重程度的分析与可视化 原文&#xff1a; Analysis and visualization of accidents severity based on LightGBM-TPE 1. 引言部分 文章开篇强调了道路交通事故作为意外死亡的主要原因&#xff0c;引起了多学科领域的关注。分析事故严重性特…

什么是CRM系统,它的作用是什么?CRM全面指南

CRM&#xff08;Customer Relationship Management&#xff0c;客户关系管理&#xff09;系统是一种专门用于集中管理客户信息、优化销售流程、提升客户满意度、支持精准营销、驱动数据分析决策、加强跨部门协同、提升客户生命周期价值的业务系统工具。其中&#xff0c;优化销售…

MYSQL之库的操作

创建数据库 语法很简单, 主要是看看选项(与编码相关的): CREATE DATABASE [IF NOT EXISTS] db_name [create_specification [, create_specification] ...] create_specification: [DEFAULT] CHARACTER SET charset_name [DEFAULT] COLLATE collation_name 1. 语句中大写的是…

Linux 下的网络管理(附加详细实验案例)

一、简单了解 NM&#xff08;NetworkManager&#xff09; 在 Linux 中&#xff0c;NM 是 NetworkManager 的缩写。它是一个用于管理网络连接的守护进程和工具集。 在 RHEL9 上&#xff0c;使用 NM 进行网络配置&#xff0c;ifcfg &#xff08;也称为文件&#xff09;将不再…

基于SpringBoot的疫情居家检测管理系统(源码+数据库)

514基于SpringBoot的疫情居家检测管理系统&#xff0c;系统包含三种角色&#xff1a;管理员、用户、医生&#xff0c;主要功能如下。 【用户功能】 1. 首页&#xff1a;获取系统信息。 2. 论坛&#xff1a;参与居民讨论和分享信息。 3. 公告&#xff1a;查看社区发布的各类公告…

MATLAB 控制系统设计与仿真 - 35

MATLAB鲁棒控制器分析 所谓鲁棒性是指控制系统在一定(结构&#xff0c;大小)的参数扰动下&#xff0c;维持某些性能的特征。 根据对性能的不同定义&#xff0c;可分为稳定鲁棒性(Robust stability)和性能鲁棒性(Robust performance)。 以闭环系统的鲁棒性作为目标设计得到的…

性能比拼: Nginx vs Caddy

本内容是对知名性能评测博主 Anton Putra Nginx vs Caddy Performance 内容的翻译与整理, 有适当删减, 相关指标和结论以原作为准 引言 在本期视频中&#xff0c;我们将对比 Nginx 和 Caddy---一个用 Go 编写的 Web 服务器和反向代理。 在第一个测试中&#xff0c;我们会使用…

C++项目-衡码云判项目演示

衡码云判项目是什么呢&#xff1f;简单来说就是这是一个类似于牛客、力扣等在线OJ系统&#xff0c;用户在网页编写代码&#xff0c;点击提交后传递给后端云服务器&#xff0c;云服务器将用户的代码和测试用例进行合并编译&#xff0c;返回结果到网页。 项目最大的两个亮点&…

李宏毅NLP-6-seq2seqHMM

比较seq2seq和HMM Hidden Markov Model(HMM) 隐马尔可夫模型&#xff08;HMM&#xff09;在语音识别中的应用&#xff0c;具体内容如下&#xff1a; 整体流程&#xff1a; 左侧为语音信号&#xff08;标记为 “speech”&#xff09;&#xff0c;其特征表示为 X X X。中间蓝色模…

百度暑期实习岗位超3000个,AI相关岗位占比87%,近屿智能携AIGC课程加速人才输出

今年3月&#xff0c;百度重磅发布3000暑期实习岗位&#xff0c;聚焦大模型、机器学习、自动驾驶等AI方向的岗位比例高达87%。此次实习岗位涉及技术研发、产品策划、专业服务、管理支持、政企解决方案等四大类别&#xff0c;覆盖超300个岗位细分方向。值得一提的是&#xff0c;百…