Mobile net V系列详解 理论+实战(2)

news2025/3/13 4:40:04

请添加图片描述

Mobilenet 系列

  • 实践部分
  • 一、数据集介绍
  • 二、模型整体框架
  • 三、模型代码详解
  • 四、总结

实践部分

本章针对实践通过使用pytorch一个实例对这部分内容进行吸收分析。本章节采用的源代码在这里感兴趣的读者可以自行下载操作。

一、数据集介绍

可以看到数据集本身被存放在了三个文件夹下,其主要是花的图片,被分割成了验证集和训练集,模型训练主要就是采用训练集中的数据进行训练,验证集则用来对模型的性能进行测试。
请添加图片描述
为了进一步增强数据集的结构化和规范化,每个图像通常会被放置在代表其类别的文件夹中。这意味着所有同类别的图像会被存放在相同的文件夹里。这样的存放方式不仅使数据集的管理变得简单化,更重要的是,为使用自动化工具提供了便利。例如,图像数据集的这种标准存放形式完美支持了 PyTorch 中的DatasetFolder工具直接进行处理。请添加图片描述
前几章节在实战部分讲述过,可以省却重复编码自定义Dataset类的复杂过程。DatasetFolder工具能够直观地从这种组织形式的数据集中加载图像及其对应标签,大幅简化了数据预处理和加载的步骤。

二、模型整体框架

在深度学习模型的训练和部署过程中,整个工程项目通常围绕着以下三个核心文件进行组织,进而构建起模型的完整架构。这些文件分别负责不同的任务,协同工作以实现模型的训练、评估和应用。

  1. 模型模块(Model Module) - 位于心脏位置的模型模块,负责存放模型的主体架构。它定义了模型的各个层、前向传播逻辑以及计算过程,是整个深度学习任务的基础和核心。

  2. 训练文件(Training ) - 这个脚本文件负责驱动模型的训练过程。它通过调用先前准备好的数据集及模型模块,以特定的训练策略(例如学习率调整、批处理大小选择等)对模型进行训练。该文件通常会包含模型训练、验证过程,并输出训练过程中的性能指标,如损失和准确率等。

  3. 预测模块(Prediction) - 一旦模型被训练并优化到满意的状态,预测模块则负责将这个训练好的模型导入并应用到后续的任务中。无论是用于进一步的分析、应对实时的预测请求,还是集成至更广阔的系统中,预测模块都为模型的实际使用提供了接口。

将围绕这三个文件对整个模型的框架进行展开讲解。
请添加图片描述

三、模型代码详解

首先看下模型所需要的函数部分:

import os # 文件和文件夹提供一系列操作的工具,当前文件中主要用来查找模块文件的路径地址
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim # 优化方法Adam之类的优化算法
from torchvision import transforms, datasets # 数据集操作
from tqdm import tqdm # 进度条
from model_v2 import MobileNetV2 # 编写的模型主题框架文件

接下来看train的主体文件:

def main(): # 主函数在当前文件下直接执行
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 判断下GPU是否有效
    print("using {} device.".format(device)) # 输出下在什么设备上运行的

    batch_size = 16 # 批大小
    epochs = 5 # 全部周期

    data_transform = {
    # 即对打开的图片如何处理再送入模型,数据增强技术 .Compose将做种方式进行整合,可以按照字典的方式进行调取使用
        "train":  transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

transforms.Compose是PyTorch的torchvision.transforms模块中的一个功能,用于组合多个图像变换操作。以下是这一系列变换操作的具体作用解释:

  1. transforms.RandomResizedCrop(224):

    • 这个变换随机地对图像进行裁剪,并将裁剪后的图像缩放到给定的大小(在这个例子中是224x224像素)。这种变换能够在一定程度上减少模型对图像特定部分的依赖,提高模型对于图像位置变化的鲁棒性,常用于数据增强。
  2. transforms.RandomHorizontalFlip():

    • 随机地水平翻转图像。对于每个图像,它有50%的概率被翻转。这种变换能够增加数据的多样性,帮助模型学习到对于水平方向不变性的特征,减少过拟合。
  3. transforms.ToTensor():

    • 将PIL图像或NumPy的ndarray转换为PyTorch的Tensor。这个操作还会自动将图像的数据从0到255的整数映射到0到1的浮点数,标准化图像的数据范围。
  4. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]):

    • 对图像进行标准化,即减去均值(mean)后再除以标准差(std)进行归一化。这里的均值[0.485, 0.456, 0.406]和标准差[0.229, 0.224, 0.225]是针对每一个通道的(通常为RGB通道)。这样的归一化有助于加速训练过程,减少模型对原始图像灰度尺度的依赖。
    • 这组特定的均值和标准差来自ImageNet数据集的统计,是很多预训练模型使用的标准化参数。如果你使用这些预训练模型,采用相同的归一化参数可以保持数据的一致性。

训练集合中这一组变换操作首先对图像进行了数据增强(通过随机裁剪和随机水平翻转),然后转换为了模型训练需要的Tensor格式,并且对图像进行了标准化处理,以便用于模型的训练。这些步骤是进行模型训练时常见的图像预处理流程。

测试集合中操作集合:

  1. transforms.Resize(256):

    • 首先对图像进行缩放,使其最短边的长度为256像素。这步是为了保证图像的尺寸一致性,为后续的裁剪操作做准备。
  2. transforms.CenterCrop(224):

    • 接下来执行中心裁剪,从缩放后的图像中裁切出一个大小为224x224像素的中心区域。中心裁剪通常用在验证和测试集的图像预处理中,旨在减少模型对图像边缘部分的依赖,同时保留图像最关键的内容区域。
  3. transforms.ToTensor():

    • 然后将处理过的图像转换为PyTorch Tensor,并自动将数值范围从[0, 255]归一化到[0, 1]。这是为了使图像数据适配PyTorch模型的输入要求。
  4. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]):

    • 最后,对图像的每个通道执行标准化操作。具体来说,使用给定的均值([0.485, 0.456, 0.406])和标准差([0.229, 0.224, 0.225])对图像的RGB通道进行标准化。这一步骤是基于ImageNet数据集的图像统计特性,可以进一步提升模型的泛化能力。标准化有助于加速模型训练,提高模型性能。

os.getcwd() 是Python中的一个函数,隶属于os(操作系统)模块。getcwdget current working directory的缩写,这个函数的作用是返回当前工作目录的绝对路径。

在Python程序中,当前工作目录指的是执行当前代码时所在的文件系统目录。

以下是一个简单的使用例子:

import os

# 获取并打印当前工作目录
current_directory = os.getcwd()
print("当前工作目录是:", current_directory)

下述代码找目录,就是找数据集的位置,用来传数据集,由于其为通用代码所以作者为了减少用户修改代码的必要再次进行模型自动调用。
如果你在命令行中运行上述Python脚本,它会打印出从哪个目录运行了Python解释器。了解当前的工作目录对于执行与文件路径操作相关的任务非常有用,比如读取或写入到相对路径的文件。
通过和"…/…"拼接找上两级的菜单作为当前图片的路径信息,如果要运行就自行修改。

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path

使用断言,如果这个路径不存在就报错,确保有数据集

    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

这个功能和pytorch中的另一个模块比较像:
在PyTorch中,ImageFolderDatasetFolder是两个用来加载数据的类,它们确实有相似之处,但也有一些关键区别。详细地解析一下:

相似之处

  • 目的相同:两者都用于加载数据集,特别是那些按文件夹组织的数据集,其中每个文件夹包含一个类别的数据。
  • 简化数据加载:它们提供了简洁的接口来加载数据,减少了编写自定义加载逻辑的需要,通过transforms参数,还可以很方便地对数据进行预处理和增强。

关键区别

  1. 使用场景

    • ImageFolder特别适用于图像数据,它假定数据集是以文件夹方式组织的,其中每个文件夹对应一个类别的图像。它自动将文件夹的名字作为类别的标签。
    • DatasetFolder则更为通用,可以用来加载任何类型的数据,只要数据是按类别组织在不同文件夹中。它允许通过loader参数自定义如何加载数据,这意味着您可以定义加载图像、文本文件或其他类型文件的函数。
  2. 灵活性:#实际上是DatasetFolder的一个图片领域的应用,即在DatasetFolder中要规定如何打开这个数据,则这应用特例则直接内部定义好了,极简化处理

    • ImageFolder内部实际上是DatasetFolder一个具体实现,特化于处理图像文件,并且预设了使用PIL库来加载图像。这使得ImageFolder使用起来更加简单直观,特别是对于图像数据。
    • DatasetFolder提供了更多的自定义选项,比如自定义加载函数(loader)和数据后缀(extensions),从而可以更灵活地加载不同类型的文件数据。

示例

使用ImageFolder加载图像数据:

from torchvision.datasets import ImageFolder
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

dataset = ImageFolder(root='path/to/data', transform=transform)

使用DatasetFolder加载非图像类型的数据集:

from torchvision.datasets import DatasetFolder
from torchvision import transforms
from my_custom_loader import custom_loader_function

dataset = DatasetFolder(root='path/to/data', loader=custom_loader_function, extensions=('txt',), transform=some_transforms)

总之,虽然ImageFolderDatasetFolder有相似之处,它们都提供了用于加载和处理以文件夹为单位组织的数据集的便捷方法,但DatasetFolder的设计更为通用,提供了更大的灵活性,而ImageFolder则专门用于处理图像数据,使用起来更加方便简洁。

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train") 
                                           transform=data_transform["train"])
    train_num = len(train_dataset) # 判断下数据集的长度
#获取属性到类别的映射
    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4) #将python对象编码成Json字符串 indent:参数根据数据格式缩进显示,读起来更加清晰。
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
# 具体流程就是通过使用class_to_idx得到索引映射信息,使用for进行辩论获取。反转位置将文件写入一个json字符串中,并创建一个文件夹对这部分数据进行保存。
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers 线程数量 计算单个批次损失你多个size就可以一起运行,多个size在不同的核上使用相同的模型计算,得到损失更新参数。
    print('Using {} dataloader workers every process'.format(nw)) # 输出最终决定使用的线程数量

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)
                                               # 创建加载器。迭代数据集

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # create model
    net = MobileNetV2(num_classes=5) # 实例化模型仅有最终类别需要进行设置

    # load pretrain weights
    # download url: https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
    model_weight_path = "./mobilenet_v2.pth"
    assert os.path.exists(model_weight_path), "file {} dose not exist.".format(model_weight_path)
    pre_weights = torch.load(model_weight_path, map_location='cpu')

    # delete classifier weights
    pre_dict = {k: v for k, v in pre_weights.items() if net.state_dict()[k].numel() == v.numel()}
    missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)

    # freeze features weights
    for param in net.features.parameters():
        param.requires_grad = False

    net.to(device)

    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    best_acc = 0.0
    save_path = './MobileNetV2.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # loss = loss_function(outputs, test_labels)
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

四、总结

论文部分介绍的是mobilenet V1代码部分则是V2下一章节将对这差异部分进行详细的分析,及其模型核心代码的改变进行详细的指出,加油加油,明天就发。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2147676.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

处理RabbitMQ连接和认证问题

在使用RabbitMQ进行消息队列管理时,我们可能会遇到各种连接和认证问题。本文将介绍如何诊断和解决这些问题,并通过使用RabbitMQ的管理端进行登录验证来确保配置正确。 1. 问题概述 在最近的一次部署中,我们遇到了两个主要问题: …

一对一,表的设计

表很大,比如用户 用户登录只需要部分数据,所以把用户表拆成两个表 用户登录表 用户信息表 一对一设计有两种方案: 加外键,唯一 主键共享

学生考试成绩老师发布平台

老师们一直肩负着传授知识与评估学生学习成果的双重责任。其中,发布学生考试成绩是教学过程中不可或缺的一环。然而,传统的成绩发布方式往往繁琐且耗时。老师们需要手动整理成绩,然后通过电话、短信或电子邮件逐一通知学生和家长,…

Jenkins设置自动拉取代码后怎么设置自动执行构建任务?

在 Jenkins 中设置自动拉取代码后,可以通过以下步骤设置自动执行构建任务: 一、配置构建触发器 打开已经设置好自动拉取代码的 Jenkins 任务。在 “构建触发器” 部分,除了 “Poll SCM”(用于定时检查代码仓库更新)外…

Mybatis 和 数据库连接

第一次要下载驱动 查询数据库版本 但是在idea查看数据库我不行,插件我也装了,然后我在尝试改版本。也不行。 爆错 感觉还是插件的问题。先不弄了,影响不大。 但是加载了这个,能在idea写sql语句,还能有提示。

【IPOL阅读】点云双边滤波

文章目录 简介点云滤波处理结果 简介 IPOL,即Image Processing On Line,理论上是一个期刊,但影响因子很低,只是个SCIE,按理说没什么参考价值。但是,这个网站的所有文章,都附带了源代码和演示窗…

【三步搭建 本地 编程助手 codegeex】

这里写目录标题 第一步 ollama安装常见报错 第二步 下载启动模型下载启动模型常见问题 第三步配置codegeex安装插件本地配置 其他 如果可以联网,vscode装个codegeex插件即可,本次搭建的本地编程助手,解决因安全问题完全无网络的情况下的编程助…

诗文发布模板(python代码打造键盘录入诗文自动排版,MarkDown源码文本)

python最好用的f-string,少量代码打造键盘录入诗文自动排版。 (笔记模板由python脚本于2024年09月19日 19:11:50创建,本篇笔记适合喜欢写诗的pythoner的coder翻阅) 【学习的细节是欢悦的历程】 Python 官网:https://www.python.org/ Free&am…

新手入门大模型教程(非常详细)零基础入门到精通,收藏这一篇就够了

目前大模型非常的火,国内开始流行大模型应用,那么作为程序员对于大模型有什么要了解和学习的我们今天就来研究下。 深度学习基础 因为大模型也是人工智能,人工智能就要先学习一下深度学习,深度学习是机器学习领域中的一个方向。…

Linux通过yum安装Docker

目录 一、安装环境 1.1. 旧的docker包卸载 1.2. 安装常规环境包 1.3. 设置存储库 二、安装Docker社区版 三、解决拉取镜像失败 3.1. 创建文件目录/etc/docker 3.2. 写入镜像配置 https://docs.docker.com/engine/install/centos/ 检测操作系统版本,我操作的…

英飞凌最新AURIX™TC4x芯片介绍

概述: 英飞凌推出最新的AURIX™TC4x系列,突破了电动汽车、ADAS、汽车e/e架构和边缘应用人工智能(AI)的界限。这一代面向未来的微控制器将有助于克服安全可靠的处理性能和效率方面的限制。客户将可缩短快速上市时间并降低整体系统成本。为何它被称为汽车市场新出现的主要颠覆…

SourceTree保姆级教程1:(克隆,提交,推送)

本人认为sourceTree 是最好用的版本管理工具,下面将讲解下sourceTree 客户端工具 克隆,提交,推送 具体使用过程,废话不多说直接上图。 使用步骤: 首先必须要先安装Git和sourceTree,如何按照参考其它文章&…

计算机网络:概述 --- 体系结构

目录 一. 体系结构总览 1.1 OSI七层协议体系结构 1.2 TCP/IP四层(或五层)模型结构 二. 数据传输过程 2.1 同网段传输 2.2 跨网段传输 三. 体系结构相关概念 3.1 实体 3.2 协议 3.3 服务 这里我们专门来讲一下计算机网络中的体系结构。其实我们之前…

力扣1143-最长公共子序列(Java详细题解)

题目链接:1143. 最长公共子序列 - 力扣(LeetCode) 前情提要: 如果你做过718. 最长重复子数组 - 力扣(LeetCode)并且看过我的这篇题解力扣718-最长重复子数组(Java详细题解)-CSDN博…

大数据新视界 --大数据大厂之SaaS模式下的大数据应用:创新与变革

💖💖💖亲爱的朋友们,热烈欢迎你们来到 青云交的博客!能与你们在此邂逅,我满心欢喜,深感无比荣幸。在这个瞬息万变的时代,我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

网站在线客服插件配置

使用工具:百度爱番番 下载地址: 百度爱番番—企业的一站式智能营销管家 一、下载百度爱番番APP,注册账号 二、 登录app 三、点击设置——站点设置——新建站点 四、设置站点名称——站点地址——PC站点——确定 五、点击配置好的站点的获取代…

Linux新增用户,对用户提权

文章目录 一、创建用户二、删除用户三、对用户进行提权 一、创建用户 adduser进行创建用户,名字最好不用和指令名称相同。 在创建完用户时最好使用sudo passwd username进行对用户密码的修改. 二、删除用户 userdel进行对用户的删除 三、对用户进行提权 新建用…

电商好用的客服话术

在电商交易中,良好的客户服务至关重要。优质的售前服务能够帮助顾客更好地了解商品,做出明智的购买决策;而高效的售后服务则能提升顾客的满意度和忠诚度。今天给大家分享了一些好用的客服售前售后话术。 一、售前 “质量好”相关话术:亲亲&a…

工厂ERP采购管理,销售管理,仓库管理,财务管理,生产加工管理建设方案和源码实现(JAVA)

工厂进销存管理系统是一个集采购管理、仓库管理、生产管理和销售管理于一体的综合解决方案。该系统旨在帮助企业优化流程、提高效率、降低成本,并实时掌握各环节的运营状况。 在采购管理方面,系统能够处理采购订单、供应商管理和采购入库等流程&#xff…

基于SSM的宿舍管理系统的设计与实现 (含源码+sql+视频导入教程+文档+PPT)

👉文末查看项目功能视频演示获取源码sql脚本视频导入教程视频 1 、功能描述 基于SSM的宿舍管理系统9拥有两种角色:管理员和用户 管理员:宿舍管理、学生管理、水电费管理、报修管理、访客管理、各种信息统计报表 用户:个人信息管…