基于Pytorch框架的深度学习EfficientNet神经网络香蕉水果成熟度识别分类系统源码

news2024/11/15 19:24:45

 第一步:准备数据

4种香蕉水果成熟度数据:overripe,ripe,rotten,unripe(过熟、熟、烂、未成熟),总共有13474张图片,每个文件夹单独放一种成熟度数据

第二步:搭建模型

本文选择一个EfficientNet网络,其原理介绍如下:

        为了弄清楚神经网络缩放之后的效果,谷歌团队系统地研究了改变不同维度对模型的影响,维度参数包括网络深度、宽度和图像分辨率。首先他们进行了栅格搜索(Grid Search)。这是一种穷举搜索方法,可以在固定资源的限定下,列出所有参数之间的关系,显示出改变某一种维度时,基线网络模型会受到什么样的影响。换句话说,如果只改变了宽度、深度或分辨率,模型的表现会发生什么变化。

        综合考虑所有情况之后,他们确定了每个维度最合适的调整系数,然后将它们一同应用到基线网络中,对每个维度都进行适当的缩放,并且确保其符合目标模型的大小和计算预算。

        简单来说,就是分别找到宽度、深度和分辨率的最佳系数,然后将它们组合起来一起放入原本的网络模型中,对每一个维度都有所调整。从整体的角度缩放模型。与传统方法相比,这种复合缩放法可以持续提高模型的准确性和效率。在现有模型 MobileNet 和 ResNet 上的测试结果显示,它分别提高了 1.4% 和 0.7% 的准确率。

         因为,为了进一步提高性能,谷歌 AI 团队还使用了 AutoML MNAS 框架进行神经架构搜索,优化准确性和效率。AutoML 是一种可以自动设计神经网络的技术,由谷歌团队在 2017 年提出,而且经过了多次优化更新。使用这种技术可以更简便地创造神经网络。由此产生的架构使用了移动倒置瓶颈卷积(MBConv),类似于 MobileNetV2 和 MnasNet 模型,但由于计算力(FLOPS)预算增加,MBConv 模型体积略大。随后他们多次缩放了基线网络,组成了一系列模型,统称为 EfficientNets。

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import os
import math
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler

from model import efficientnet_b0 as create_model
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    print(args)
    print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
    tb_writer = SummaryWriter()
    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    img_size = {"B0": 224,
                "B1": 240,
                "B2": 260,
                "B3": 300,
                "B4": 380,
                "B5": 456,
                "B6": 528,
                "B7": 600}
    num_model = "B0"

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size[num_model]),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(img_size[num_model]),
                                   transforms.CenterCrop(img_size[num_model]),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # 实例化训练数据集
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # 实例化验证数据集
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    # 如果存在预训练权重则载入
    model = create_model(num_classes=args.num_classes).to(device)
    if args.weights != "":
        if os.path.exists(args.weights):
            weights_dict = torch.load(args.weights, map_location=device)
            load_weights_dict = {k: v for k, v in weights_dict.items()
                                 if model.state_dict()[k].numel() == v.numel()}
            print(model.load_state_dict(load_weights_dict, strict=False))
        else:
            raise FileNotFoundError("not found weights file: {}".format(args.weights))

    # 是否冻结权重
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除最后一个卷积层和全连接层外,其他权重全部冻结
            if ("features.top" not in name) and ("classifier" not in name):
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        # train
        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_loader,
                                    device=device,
                                    epoch=epoch)

        scheduler.step()

        # validate
        acc = evaluate(model=model,
                       data_loader=val_loader,
                       device=device)
        print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
        tags = ["loss", "accuracy", "learning_rate"]
        tb_writer.add_scalar(tags[0], mean_loss, epoch)
        tb_writer.add_scalar(tags[1], acc, epoch)
        tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--lrf', type=float, default=0.01)

    # 数据集所在根目录
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default=r"G:\demo\data\classifier\classifier\train")

    # download model weights
    # 链接: https://pan.baidu.com/s/1ouX0UmjCsmSx3ZrqXbowjw  密码: 090i
    parser.add_argument('--weights', type=str, default='./efficientnetb0.pth',
                        help='initial weights path')
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)

第四步:统计正确率

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码的下载路径(新窗口打开链接):基于Pytorch框架的深度学习CNN神经网络香蕉水果成熟度识别分类系统源码

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1698342.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

零基础小白可以做抖音电商吗?小白做电商难度大吗?一篇全解!

大家好,我是电商花花 在直播电商的热度越来越多,更多普通的创业者都对抖音小店电商有了想法,因为很多普通 人都通过抖音小店开店卖货赚到了钱,让更多人对抖店电商产生了兴趣。 于是做抖音小店无货源,开店卖货赚钱成为…

【软件推荐】obsidian设置

【软件推荐】obsidian设置 初始化 附件相对路径设置 打开obsidian 设置-文件与链接,找到下图的这几个设置。设置为如图所示。 插件推荐 实时渲染 你可能会想,obsidian的使用体验没有typora好呀! typora可以实时渲染,obsid…

用c++用4个凸函数(觉得啥好用用啥)去测试adam,rmsprop,adagrad算法的性能(谁先找到最优点)

为了测试 Adam、RMSProp 和 Adagrad 算法的性能,你可以使用四个凸函数进行实验。以下是一些常用的凸函数示例: Rosenbrock 函数: Booth 函数: Himmelblau 函数: Beale 函数: 你可以选择其中一个或多…

光线追踪技术在AI去衣中的革命性角色

引言: 随着人工智能和计算机图形学的飞速发展,AI去衣技术已经从理论走向实践,为影视制作、虚拟现实、在线试衣等领域提供了强大的技术支持。在这一过程中,光线追踪技术以其卓越的渲染能力和逼真的光影效果,成为AI去衣领…

C++开发面试常问总结

一些面试总结 TCP粘包了解吗?解决办法?讲一下乐观锁悲观锁git中 git pull和git fetch的区别1.虚函数实现机制:2.进程和线程的区别:3.TCP三次握手、四次挥手:4.HTTP状态码,报头:5.智能指针&#…

MySql基础(一)--最详细基础入门,看完就懂啦(辛苦整理,想要宝宝的赞和关注嘻嘻)

前言 希望你向太阳一样,有起有落,不失光彩~ 一、数据库概述 1. 什么是数据库 数据库就是存储数据的仓库,其本质是一个文件系统,数据按照特定的格式将数据存储起来,用户可以对数据库中的数据进行增加,修改&…

【Django项目】 音乐网站spotify复刻

代码:https://github.com/tomitokko/spotify-clone 注:该项目不是自己提供mp3文件,而是使用spotify 的api接口获取。

奇舞周刊第529期:万字长文入门前端全球化

周五快乐(图片由midjourney生成) 奇舞推荐 ■ ■ ■ 万字长文入门前端全球化 目前国内企业正积极开拓国际市场,国际化已成为重要的发展方向,因此产品设计和开发更需考虑国际化。本文介绍了语言标识、文字阅读顺序等诸多知识。然后…

【编译原理复习笔记】中间语言

中间语言 中间语言的特点和作用 (1)独立于机器 (2)复杂性介于源语言和目标语言之间 中间语言可以使编译程序的结构在逻辑上更为简单明确 常用的中间语言 后缀式 图表示:抽象语法树,有向无环图 三地址代…

淘宝x5sec

声明 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关!wx a15018601872 本文章未…

LeetCode 264 —— 丑数 II

阅读目录 1. 题目2. 解题思路3. 代码实现 1. 题目 2. 解题思路 第一个丑数是 1 1 1,由于丑数的质因子只包含 2 、 3 、 5 2、3、5 2、3、5,所以后面的丑数肯定是前面的丑数分别乘以 2 、 3 、 5 2、3、5 2、3、5 后得到的数字。 这样,我…

类的内存对齐位段位图布隆过滤器哈希切割一致性哈希

文章目录 一、类的内存对齐1.1规则1.2原因 二、位段2.1介绍2.2内存分配问题2.3跨平台问题2.4使用的注意事项 三、位图的应用3.1 给40亿个不重复的无符号整数,找给定的一个数。(int的范围可以到达42亿多)3.2 给定100亿个整数,设计算…

openEuler 22.03 LTS SP3源码编译部署OpenStack-Caracal

openEuler 22.03 LTS SP3源码编译部署OpenStack-Caracal 说明机器详情安装操作系统注意事项基础准备Controller节点 && Compute节点 && Block节点关闭防火墙关闭selinux设置静态IP更新安装前准备Controller节点 && Compute节点 && Block节点设…

auto关键字(C++11)

auto关键字(C11) 文章目录 auto关键字(C11)前言一、auto使用规则二、auto不适用的场景三、auto推荐适用的场景总结 前言 在C11中,auto关键字能够自动推导出变量的实际类型,可以帮助我们写出更加简洁、现代…

开发公众号自定义菜单之创建菜单

文章目录 申请测试账号换取Token接口测试提交自定义菜单查看效果校验菜单配置清空菜单配置结束语 申请测试账号 https://mp.weixin.qq.com/debug/cgi-bin/sandboxinfo?actionshowinfo&tsandbox/index 或 得到appid和secret 换取Token 使用appid和secret换取token令牌…

Python应用实战,用动画生成冒泡排序的过程

写在前言 hello,大家好,我是一点,专注于Python编程,如果你也对感Python感兴趣,欢迎关注交流。 希望可以持续更新一些有意思的文章,如果觉得还不错,欢迎点赞关注,有啥想说的&#x…

解决文件夹打开出错问题:原因、数据恢复与预防措施

在我们日常使用电脑或移动设备时,有时会遇到一个非常棘手的问题——文件夹打开出错。这种错误可能会让您无法访问重要的文件和数据,给工作和生活带来极大的不便。本文将带您深入了解文件夹打开出错的原因,并提供有效的数据恢复方案&#xff0…

栈和队列的基本见解

1.栈 1.1栈的基本概念和结构: 栈是一种特殊的线性表,其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端称为栈顶,另一端称为栈底。栈中的数据元素遵守后进先出的原则。 压栈:栈的插入操作叫做进栈/压栈…

Vxe UI 表单设计器、零代码平台

vxe-pc-ui Vxe UI 表单设计器、零代码表单设计器 安装 Vxe UI PC端组件库 官方文档 查看 github、gitee // ...import VxeUI from vxe-pc-uiimport vxe-pc-ui/lib/style.css// ...// ...createApp(App).use(VxeUI).mount(#app)// ...使用 vxe-form-design 设计器组件 vxe-fo…

QML的Image 路径问题(source)

四种路径格式 在 QML 中,当你使用 Image 元素的 source 属性来指定一个图片的路径时,有几种不同的方式可以指定这个路径,每种方式都有其特定的用途和上下文。 相对路径: QML 文件和一个名为 close.png 的图片在同一目录下&#x…