昇思25天学习打卡营第22天 | 基于MobileNetv2的垃圾分类函数式自动微分

news2025/1/11 1:51:13

基于MobileNetV2的垃圾分类

在这里插入图片描述

本文档详细介绍了使用MobileNetV2模型进行垃圾分类的全过程,包括数据准备、模型搭建、模型训练、评估和推理等步骤。MobileNetV2是一种轻量级卷积神经网络,专为移动端和嵌入式设备设计,具有高效、低耗的特点。通过将该模型应用于垃圾分类任务,我们可以自动识别和分类不同类型的垃圾,提高垃圾处理的效率。

本文档介绍了使用MobileNetV2模型进行垃圾分类的代码开发过程。我们将通过读取本地图像数据作为输入,对图像中的垃圾物体进行检测,并将检测结果图片保存到文件中。

1. 实验目的
  • 熟悉垃圾分类应用代码的编写(Python语言)。
  • 了解Linux操作系统的基本使用。
  • 掌握atc命令进行模型转换的基本操作。
2. MobileNetV2模型原理介绍

MobileNetV2是Google团队于2018年提出的一种轻量级卷积神经网络,专注于移动端、嵌入式或IoT设备。相比传统的卷积神经网络,MobileNetV2使用深度可分离卷积(Depthwise Separable Convolution),在准确率小幅度降低的前提下,大大减小了模型参数与运算量。

MobileNetV2通过引入倒残差结构(Inverted Residual Block)和线性瓶颈(Linear Bottlenecks)来设计网络,以提高模型的准确率,同时优化后的模型更小。

3. 实验环境

本案例支持Win_x86和Linux系统,CPU/GPU/Ascend均可运行。

4. 数据处理
4.1 数据准备

MobileNetV2的代码默认使用ImageFolder格式管理数据集。每类图片整理成单独的一个文件夹,数据集结构如下:

└─ImageFolder
  ├─train
  │   ├─class1Folder
  │   └─......
  └─eval
      ├─class1Folder
      └─......
4.2 数据加载
import math
import numpy as np
import os
import random
from matplotlib import pyplot as plt
from easydict import EasyDict
from PIL import Image
import mindspore.nn as nn
import mindspore.dataset as de
import mindspore.dataset.vision as C
import mindspore.dataset.transforms as C2
import mindspore as ms
from mindspore import set_context, Tensor
from mindspore.train import Model
from mindspore.train import Callback, LossMonitor, ModelCheckpoint, CheckpointConfig

os.environ['GLOG_v'] = '3'
os.environ['GLOG_logtostderr'] = '0'
os.environ['GLOG_log_dir'] = '../../log'
os.environ['GLOG_stderrthreshold'] = '2'
set_context(mode=ms.GRAPH_MODE, device_target="CPU", device_id=0)

# 数据集标签和字典
garbage_classes = {
    '干垃圾': ['贝壳', '打火机', '旧镜子', '扫把', '陶瓷碗', '牙刷', '一次性筷子', '脏污衣服'],
    '可回收物': ['报纸', '玻璃制品', '篮球', '塑料瓶', '硬纸板', '玻璃瓶', '金属制品', '帽子', '易拉罐', '纸张'],
    '湿垃圾': ['菜叶', '橙皮', '蛋壳', '香蕉皮'],
    '有害垃圾': ['电池', '药片胶囊', '荧光灯', '油漆桶']
}

class_cn = ['贝壳', '打火机', '旧镜子', '扫把', '陶瓷碗', '牙刷', '一次性筷子', '脏污衣服',
            '报纸', '玻璃制品', '篮球', '塑料瓶', '硬纸板', '玻璃瓶', '金属制品', '帽子', '易拉罐', '纸张',
            '菜叶', '橙皮', '蛋壳', '香蕉皮',
            '电池', '药片胶囊', '荧光灯', '油漆桶']
class_en = ['Seashell', 'Lighter','Old Mirror', 'Broom','Ceramic Bowl', 'Toothbrush','Disposable Chopsticks','Dirty Cloth',
            'Newspaper', 'Glassware', 'Basketball', 'Plastic Bottle', 'Cardboard','Glass Bottle', 'Metalware', 'Hats', 'Cans', 'Paper',
            'Vegetable Leaf','Orange Peel', 'Eggshell','Banana Peel',
            'Battery', 'Tablet capsules','Fluorescent lamp', 'Paint bucket']

index_en = {'Seashell': 0, 'Lighter': 1, 'Old Mirror': 2, 'Broom': 3, 'Ceramic Bowl': 4, 'Toothbrush': 5, 'Disposable Chopsticks': 6, 'Dirty Cloth': 7,
            'Newspaper': 8, 'Glassware': 9, 'Basketball': 10, 'Plastic Bottle': 11, 'Cardboard': 12, 'Glass Bottle': 13, 'Metalware': 14, 'Hats': 15, 'Cans': 16, 'Paper': 17,
            'Vegetable Leaf': 18, 'Orange Peel': 19, 'Eggshell': 20, 'Banana Peel': 21,
            'Battery': 22, 'Tablet capsules': 23, 'Fluorescent lamp': 24, 'Paint bucket': 25}

# 训练超参
config = EasyDict({
    "num_classes": 26,
    "image_height": 224,
    "image_width": 224,
    "backbone_out_channels":1280,
    "batch_size": 16,
    "eval_batch_size": 8,
    "epochs": 10,
    "lr_max": 0.05,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "save_ckpt_epochs": 1,
    "dataset_path": "./data_en",
    "class_index": index_en,
    "pretrained_ckpt": "./mobilenetV2-200_1067.ckpt"
})

def create_dataset(dataset_path, config, training=True, buffer_size=1000):
    """
    创建训练或评估数据集

    Args:
        dataset_path (string): 数据集路径。
        config (struct): 训练和评估配置。

    Returns:
        ds (dataset): 返回训练或评估数据集。
    """
    data_path = os.path.join(dataset_path, 'train' if training else 'test')
    ds = de.ImageFolderDataset(data_path, num_parallel_workers=4, class_indexing=config.class_index)
    resize_height = config.image_height
    resize_width = config.image_width
    
    normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255])
    change_swap_op = C.HWC2CHW()
    type_cast_op = C2.TypeCast(mstype.int32)

    if training:
        crop_decode_resize = C.RandomCropDecodeResize(resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333))
        horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5)
        color_adjust = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
    
        train_trans = [crop_decode_resize, horizontal_flip_op, color_adjust, normalize_op, change_swap_op]
        train_ds = ds.map(input_columns="image", operations=train_trans, num_parallel_workers=4)
        train_ds = train_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=4)
        train_ds = train_ds.shuffle(buffer_size=buffer_size)
        ds = train_ds.batch(config.batch_size, drop_remainder=True)
    else:
        decode_op = C.Decode()
        resize_op = C.Resize((int(resize_width/0.875), int(resize_width/0.875)))
        center_crop = C.CenterCrop(resize_width)
        
        eval_trans = [decode_op, resize_op, center_crop, normalize_op, change_swap_op]
        eval_ds = ds.map(input_columns="image", operations=eval_trans, num_parallel_workers=4)
        eval_ds = eval_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=4)
        ds = eval_ds.batch(config.eval_batch_size, drop_remainder=True)
    return ds

# 展示部分处理后的数据
ds = create_dataset(dataset_path=config.dataset_path, config=config, training=False)
print(ds.get_dataset_size())
data = ds.create_dict_iterator(output_numpy=True)._get_next()
images = data['image']
labels = data['label']

for i in range(1, 5):
    plt.subplot(2, 2, i)
    plt.imshow(np.transpose(images[i], (1,2,0)))
    plt.title('label: %s' % class_en[labels[i]])
    plt.xticks([])
plt.show()
5. MobileNetV2模型搭建

使用MindSpore定义MobileNetV2网络的各模块时需要继承mindspore.nn.Cell。Cell是所有神经网络(如Conv2d等)的基类。以下是MobileNetV2模型的定义:

__all__ = ['

mobilenet_v2']

def conv_bn(inp, oup, stride):
    return nn.SequentialCell([
        nn.Conv2d(inp, oup, 3, stride, pad_mode='pad', padding=1, has_bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6()
    ])

def conv_1x1_bn(inp, oup):
    return nn.SequentialCell([
        nn.Conv2d(inp, oup, 1, 1, pad_mode='pad', has_bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6()
    ])

class InvertedResidual(nn.Cell):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            layers.append(conv_1x1_bn(inp, hidden_dim))
        layers.extend([
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, pad_mode='pad', padding=1, group=hidden_dim, has_bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(),
            nn.Conv2d(hidden_dim, oup, 1, 1, pad_mode='pad', has_bias=False),
            nn.BatchNorm2d(oup)
        ])
        self.conv = nn.SequentialCell(layers)

    def construct(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)

class MobileNetV2(nn.Cell):
    def __init__(self, num_classes=1000, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        interverted_residual_setting = [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        assert len(interverted_residual_setting[0]) == 4

        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * max(1.0, width_mult))
        self.features = [conv_bn(3, input_channel, 2)]
        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                if i == 0:
                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
                else:
                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
                input_channel = output_channel
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        self.features.append(nn.AvgPool2d(7))
        self.features = nn.SequentialCell(self.features)
        self.classifier = nn.SequentialCell([
            nn.Dropout(0.2),
            nn.Dense(self.last_channel, num_classes),
        ])
        self._initialize_weights()

    def construct(self, x):
        x = self.features(x)
        x = x.view(x.shape[0], -1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Conv2d):
                ms.common.initializer.XavierUniform(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.gamma.set_data(ms.common.initializer.One())
                m.beta.set_data(ms.common.initializer.Zero())
            elif isinstance(m, nn.Dense):
                m.weight.set_data(ms.common.initializer.Normal(0.01))
                if m.bias is not None:
                    m.bias.set_data(ms.common.initializer.Zero())

def mobilenet_v2(pretrained=False, **kwargs):
    model = MobileNetV2(**kwargs)
    return model

# 创建MobileNetV2模型
network = mobilenet_v2(num_classes=config.num_classes)

# 加载预训练模型
param_dict = ms.load_checkpoint(config.pretrained_ckpt)
ms.load_param_into_net(network, param_dict)
print("load pretrained mobilenet_v2 from [{}]".format(config.pretrained_ckpt))
6. 模型训练

模型训练阶段定义如下:

def init_lr(step_size):
    lr_max = config.lr_max
    total_steps = config.epochs * step_size
    warmup_steps = int(0.1 * total_steps)
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_max * (i + 1) / warmup_steps
        else:
            lr = lr_max * (0.5 + 0.5 * math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps)))
        lr_each_step.append(lr)
    return np.array(lr_each_step).astype(np.float32)

# 创建训练集
train_dataset = create_dataset(dataset_path=config.dataset_path, config=config, training=True)

# 优化器
lr = init_lr(train_dataset.get_dataset_size())
opt = nn.Momentum(network.trainable_params(), lr, config.momentum, config.weight_decay, use_nesterov=True)

# 损失函数
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

# 定义模型
model = Model(network, loss_fn=loss, optimizer=opt, metrics={'acc'})

# 模型保存配置
config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_mobilenetV2", directory="./", config=config_ck)

# 损失监控
loss_cb = LossMonitor()

# 训练模型
print("============== Starting Training ==============")
model.train(config.epochs, train_dataset, callbacks=[ckpoint_cb, loss_cb], dataset_sink_mode=True)
7. 模型评估

模型评估阶段代码如下:

# 创建评估集
eval_dataset = create_dataset(dataset_path=config.dataset_path, config=config, training=False)

# 评估模型
acc = model.eval(eval_dataset)
print("============== Acc: {} ==============".format(acc))
8. 模型推理

对于单张图片的推理,可以使用以下代码:

from PIL import Image

def read_img(img_path):
    image = Image.open(img_path).convert('RGB')
    transform = de.transforms.Compose([
        C.Resize((224, 224)),
        C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255]),
        C.HWC2CHW()
    ])
    img = transform(image)
    img = np.expand_dims(img, axis=0)
    return img

def infer(img_path):
    img = read_img(img_path)
    img_tensor = Tensor(img)
    output = model.predict(img_tensor)
    pred = np.argmax(output.asnumpy(), axis=1)
    return class_en[pred[0]]

# 读取图像进行推理
img_path = "./data_en/test/Seashell/001.jpg"
pred_label = infer(img_path)
print("Predicted label: ", pred_label)

此代码将在指定的图像文件上执行推理,并输出预测的标签。
在这里插入图片描述
通过本次实验,我收获了以下几点:

数据预处理的重要性:
数据预处理是模型训练的关键一步。通过数据增强(如随机裁剪、水平翻转和颜色调整),我们能够提升模型的泛化能力,减少过拟合的风险。

模型设计与优化:
MobileNetV2的倒残差结构(Inverted Residual Block)和线性瓶颈(Linear Bottlenecks)在保持模型准确率的同时,显著减少了参数量和计算量,展示了优秀的模型设计理念。

训练策略与技巧:
在训练过程中,学习率的设定和调整(如学习率预热和余弦退火策略)对模型的收敛速度和最终性能有很大影响。此外,使用Momentum优化器结合Nesterov动量,可以加速训练过程并提高模型准确率。

模型评估与推理:
通过对模型进行评估,我们可以了解其在测试集上的表现,及时调整训练策略。对于单张图片的推理,通过预处理步骤和模型预测,我们能够准确输出垃圾的类别。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1928440.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

昇思25天学习打卡营第21天|RNN实现情感分类

这节课学习的是RNN实现情感分类&#xff0c;情感分类是自然语言处理中的经典任务&#xff0c;是典型的分类问题。本节使用MindSpore实现一个基于RNN网络的情感分类模型。 比如&#xff1a; 输入: This film is terrible 正确标签: Negative 预测标签: Negative 输入: This film…

Matlab进阶绘图第63期—带标记线的三维填充折线图

三维填充折线图是在三维折线图的基础上&#xff0c;对其与XOY平面之间的部分进行颜色填充&#xff0c;从而能够更好地刻画细节变化。 而带标记线的三维填充折线图是在其基础上&#xff0c;添加X相同的一条或多条标记线&#xff0c;以用于进一步讨论分析。 由于Matlab中未收录…

Mongodb复合索引

学习mongodb&#xff0c;体会mongodb的每一个使用细节&#xff0c;欢迎阅读威赞的文章。这是威赞发布的第90篇mongodb技术文章&#xff0c;欢迎浏览本专栏威赞发布的其他文章。如果您认为我的文章对您有帮助或者解决您的问题&#xff0c;欢迎在文章下面点个赞&#xff0c;或者关…

Windows上LabVIEW编译生成可执行程序

LabVIEW项目浏览器(Project Explorer)中的"Build Specifications"就是用来配置项目发布方法的。在"Build Specifications"右键菜单中选取"New"&#xff0c;可以看到程序有几种不同的发布方法&#xff1a;Application(EXE)、Installer、.Net Inte…

C# 基于共享内存实现跨进程队列

C# 进程通信系列 第一章 共享内存 第二章 共享队列&#xff08;本章&#xff09; 文章目录 C# 进程通信系列前言一、实现原理1、用到的主要对象2、创建共享内存3、头部信息4、入队5、出队6、释放资源 二、完整代码三、使用示例1、传输byte[]数据2、传输字符串3、传输对象 总结…

HarmonyOS 屏幕适配设计

1. armonyOS 屏幕适配设计 1.1. 像素单位 &#xff08;1&#xff09;px (Pixels)   px代表屏幕上的像素点&#xff0c;是手机屏幕分辨率的单位&#xff0c;即屏幕物理像素单位。 &#xff08;2&#xff09;vp (Viewport Percentage)   vp是视口百分比单位&#xff0c;基于…

Java学习之SPI、JDBC、SpringFactoriesLoader、Dubbo

概述 SPI&#xff0c;Service Provider Interface&#xff0c;一种服务发现机制&#xff0c;指一些提供给你继承、扩展&#xff0c;完成自定义功能的类、接口或方法。 在SPI机制中&#xff0c;服务提供者为某个接口实现具体的类&#xff0c;而在运行时通过SPI机制&#xff0c…

Facebook未来展望:数字社交平台的进化之路

在信息技术日新月异的时代&#xff0c;社交媒体平台不仅是人们交流沟通的重要工具&#xff0c;更是推动社会进步和变革的重要力量。作为全球最大的社交媒体平台之一&#xff0c;Facebook在过去十多年里&#xff0c;不断创新和发展&#xff0c;改变了数十亿用户的社交方式。展望…

构建企业多层社会传播网络:以AI智能名片S2B2C商城小程序为例

摘要&#xff1a;在数字化转型的浪潮中&#xff0c;企业如何有效构建并优化其社会传播网络&#xff0c;已成为提升市场竞争力、深化用户关系及实现价值转化的关键。本文以AI智能名片S2B2C商城小程序为例&#xff0c;深入探讨如何通过一系列精细化的策略与技术创新&#xff0c;构…

IP地址知识点

一、IP地址组成 把一个IP地址分成两部分&#xff1a;网络号&#xff08;标识了一个局域网&#xff09;主机号&#xff08;标识了一个局域网中的设备&#xff09; 下图是通过一个路由器连接的两个局域网&#xff08;两个相邻的局域网&#xff09;&#xff0c;网络号不相同&…

AI绘画入门实践|Midjourney 的模型版本

模型分类 Midjourney 的模型主要分为2大类&#xff1a; 默认模型&#xff1a;目前包括&#xff1a;V1, V2, V3, V4, V5.0, V5.1, V5.2, V6 NIJI模型&#xff1a;目前包括&#xff1a;NIJI V4, NIJI V5, NIJI V6 模型切换 你在服务器输入框中输入 /settings&#xff1a; 回车后…

Mac电脑清理软件有哪些 MacBooster和CleanMyMac哪个好用 苹果电脑清理垃圾软件推荐 cleanmymac和柠檬清理

对于苹果电脑用户来说&#xff0c;‌选择合适的清理软件可以帮助优化电脑性能&#xff0c;‌释放存储空间&#xff0c;‌并确保系统安全。一款好用的苹果电脑清理软件&#xff0c;能让Mac系统保持良好的运行状态&#xff0c;避免系统和应用程序卡顿的产生。有关Mac电脑清理软件…

什么是MOW,以bitget钱包为例

元描述&#xff1a;MOW凭借其富有创意的故事情节和广阔的潜力在Solana上脱颖而出。本文深入探讨了其独特的概念和光明的未来。 Mouse in a Cats World (MOW)是一个基于Solana区块链的创新meme项目&#xff0c;它重新构想了一个异想天开且赋予权力的故事。在这个奇幻的宇宙中&am…

JuiceFS、Ceph 和 MinIO 结合使用

1. 流程图 将 JuiceFS、Ceph 和 MinIO 结合使用&#xff0c;可以充分利用 Ceph 的分布式存储能力、JuiceFS 的高性能文件系统特性&#xff0c;以及 MinIO 提供的对象存储接口。以下是一个方案&#xff0c;介绍如何配置和部署 JuiceFS 使用 Ceph 作为其底层存储&#xff0c;并通…

非法闯入智能监测摄像机:安全守护的新利器

在当今社会&#xff0c;安全问题愈发受到重视。随着科技的进步&#xff0c;非法闯入智能监测摄像机应运而生&#xff0c;成为保护家庭和财产安全的重要工具。这种摄像机不仅具备监控功能&#xff0c;还集成了智能识别和报警系统&#xff0c;能够在第一时间内检测到潜在的入侵行…

three.js创建基础模型

场景是一个三维空间&#xff0c;是所有物品的容器。可以将其想象成一个空房间&#xff0c;里面可以放置要呈现的物体、相机、光源等。 通过new THREE.Scene()来创建一个新的场景。 /**1. 创建场景 -- 放置物体对象的环境*/ const scene new THREE.Scene();场景只是一个三维的…

JVM(day2)

经典垃圾收集器 Serial收集 使用一个处理器或一条收集线程去完成垃圾收集工作&#xff0c;更重要的是强调在它进行垃圾收集时&#xff0c;必须暂停其他所有工作线程&#xff0c;直到它收集结束。 ParNew收集器 ParNew 收集器除了支持多线程并行收集之外&#xff0c;其他与 …

HTTP背后的故事:理解现代网络如何工作的关键(二)

一.认识请求方法(method) 1.GET方法 请求体中的首行包括&#xff1a;方法&#xff0c;URL&#xff0c;版本号 方法描述的是这次请求&#xff0c;是具体去做什么 GET方法&#xff1a; 1.GET 是最常用的 HTTP 方法. 常用于获取服务器上的某个资源。 2.在浏览器中直接输入 UR…

【实战:Django-Celery-Flower实现异步和定时爬虫及其监控邮件告警】

1 Django中集成方式一&#xff08;通用方案&#xff09; 1.1 把上面的包-复制到djagno项目中 1.2 在views中编写视图函数 1.3 配置路由 1.4 浏览器访问&#xff0c;提交任务 1.5 启动worker执行任务 1.6 查看任务结果 2 Django中集成方式二&#xff08;官方方案&#xff0…

25_Vision Transformer原理详解

1.1 简介 Vision Transformer (ViT) 是一种将Transformer架构从自然语言处理(NLP)领域扩展到计算机视觉(CV)领域的革命性模型&#xff0c;由Google的研究人员在2020年提出。ViT的核心在于证明了Transformer架构不仅在处理序列数据&#xff08;如文本&#xff09;方面非常有效&…