深度学习入门之ResNet食物图像分类

news2024/9/29 13:19:39

前言

参加了华为一个小比赛第四届MindCon-爱(AI)美食–10类常见美食图片分类,本来想实践机器学习课程的知识,后来发现图像分类任务基本都是用神经网络做,之前在兴趣课上学过一点神经网络但不多,通过这样一个完整的项目也算入门了。
代码仓库:https://github.com/fgmn/ResNet

任务

在这里插入图片描述

在这里插入图片描述

ResNet

这里主要结合官方pytorch代码和B站视频6.2 使用pytorch搭建ResNet并基于迁移学习训练进行理解。

模型

层数不同的网络许多子结构是相似的,因此对子结构的定义会有一些参数定义。
在这里插入图片描述
论文提到两种残差结构,从上面表格可以看到,左侧building block用于18,34层网络,右侧bottleneck用于50,101,152层网络。
在这里插入图片描述
左侧残差结构的实现如下,首先定义残差结构所使用的一系列层结构,stride=1时输入输出矩阵大小相同,stride=2时输出长宽均为输入的 1 2 \frac{1}{2} 21channel是通道数,和卷积核个数对应,如 3 × 3 , 64 3\times3,64 3×3,64代表使用64个大小为 3 × 3 3\times3 3×3的卷积核对输入的64个通道进行卷积运算。
之后定义正向传播过程,实际定义了网络结构,bn层定义在卷积层和激活函数之间。

class BasicBlock(nn.Module):
    expansion = 1   # 用于协调残差结构卷积核个数发生变化

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        # 输入参数:输入特征矩阵深度,输出特征矩阵深度,卷积核移动步长,下采样参数(对应虚线残差结构)
        # 定义残差结构所使用的一系列层结构
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        # 正向传播过程
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out

同理,定义右侧残差结构,之后定义ResNet如下:

class ResNet(nn.Module):

    def __init__(self,
                 block,         # 残差结构:BasicBlock/Bottleneck
                 blocks_num,    # 残差结构数
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)

        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        # 构建层结构
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        # 虚线残差结构
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        # 实线残差结构
        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)  # 全连接层

        return x

定义具体网络:

def resnet34(num_classes=1000, include_top=True):
    # 迁移学习,预训练模型路径
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet50-19c8e357.pth
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)
    
# more ...

训练

尝试使用cuda,

    # 指定训练使用设备
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

定义数据的transform,进行随机裁剪,随机翻转,标准化处理等等操作,

    # 图像处理
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),  # 随机裁剪
                                     transforms.RandomHorizontalFlip(),  # 随机翻转
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),  # 标准化处理
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

加载训练集以及验证集,施加transform,定义batch_size=16

    # ---------------------------- 数据集加载----------------------------------------
    data_root = os.path.abspath(os.path.join(os.getcwd()))  # "../.."返回上上层目录 get data root path
    image_path = os.path.join(data_root, "data_set", "food_data")  # food data set path

    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    food_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in food_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=10)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)  # 加载数据使用线程个数

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

实例化一个34层网络,加载预训练模型,基于迁移学习方法训练,

    net = resnet34()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))  # 载入模型权重

定义全连接层,交叉熵损失函数,以及Adam优化器,

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 10)
    net.to(device)

    # define loss function
    loss_function = nn.CrossEntropyLoss()  # 针对多类别,损失交叉熵函数

    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

训练并验证,保存效果最好的网络,

   epochs = 20
    best_acc = 0.0
    save_path = './resNet34.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()  # 可管理batchnorm层以及dropout方法
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()  # 清空之前的梯度信息
            logits = net(images.to(device))  # 正向传播
            loss = loss_function(logits, labels.to(device))
            loss.backward()  # 反向传播
            optimizer.step()  # 更新每个节点参数

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():  # 在验证过程中不计算损失梯度
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # loss = loss_function(outputs, test_labels)
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()  # to(device)在设备中可能有缓存

                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:  # 保存最优网络
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/150168.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android UI界面刷新机制

一 前言 作为严重影响 Android 口碑问题之一的 UI 流畅性差的问题,首先在 Android 4.1 版本中得到了有效处理。其解决方法即在 4.1 版本推出的 Project Butter。Project Butter 对 Android Display系统进行了重构,引入三个核心元素:VSYNC、T…

nmake文件学习记录(一)看《跟我一起写Makefile》

1、陈皓《跟我一起写Makefile》 makefile 带来的好处就是——“自动化编译”,一旦写好,只需要一个make 命令,整个工程完全自动编译,极大的提高了软件开发的效率。 make 是一个命令工具,是一个解释makefile 中指令的命…

线程池(ThreadPoolExecutor)

文章目录一、线程池标准库提供的线程池ThreadPoolExecutor自定义线程池一、线程池 为什么要引入线程池? 这个原因我们需要追溯到线程,我们线程存在的意义在于,使用进程进行并发编程太重了,所以引入了线程,因为线程又称为 “轻量…

【知识图谱导论-浙大】第三、四章:知识图谱的抽取与构建

前文: 【知识图谱导论-浙大】第一章:知识图谱概论 【知识图谱导论-浙大】第二章:知识图谱的表示 说明:原视频中的第三章主要介绍了图数据库相关的内容,有兴趣的可以查看相关课件或者对应的视频: 【知识图…

[Linux理论基础1]----手写和使用json完成[序列化和反序列化]

文章目录前言一、应用层二、再谈"协议"三、 网络版计算器手写版本使用第三方库json实现完整代码总结前言 理解应用层的作用,初始HTTP协议;理解传输层的作用,深入理解TCP的各项特性和机制;对整个TCP/IP协议有系统的理解;对TCP/IP协议体系下的其他重要协议和技术有一定…

JPG格式如何转为PDF格式?快来学习如何转换

图片是我们经常用到的一种便携式文件,像我们日常的照片或者是一些学习资料、工作资料都是图片形式的,我们经常会把这些图片发送给其他人,这时候就需要想一个简单的办法把图片一次性发送过去,所以我们可以将图片转换为PDF文件&…

暨 广告、推荐、搜索 三大顶级复杂业务之 “广告业务系统详叙”

文章目录暨 广告、推荐、搜索 三大顶级复杂业务之 “广告业务系统详叙”广告系统的核心功能ADX 架构流程概述典型 ADX 架构图概述消息中心抱歉,有段日子没码字了,后面会尽量补出来分享给大家。这段时间整理了关于 “广告业务” 相关的思考,作…

OSPF笔记(五):OSPF虚链路--普通区域远离骨干区域

一、OSPF 虚链路 1.1 虚链路邻居关系: hello包只发送一次,没有dead时间 虚链路配置邻居指的是RID,非接口IP 1.2 虚链路解决的问题: 普通区域远离骨干区域0的问题 普通区域连接两个骨干区域0问题 (1)…

SpringSecurity授权功能快速上手

3. 授权 3.0 权限系统的作用 例如一个学校图书馆的管理系统,如果是普通学生登录就能看到借书还书相关的功能,不可能让他看到并且去使用添加书籍信息,删除书籍信息等功能。但是如果是一个图书馆管理员的账号登录了,应该就能看到并…

最新款发布 | 德州仪器(TI)60G单芯片毫米波雷达芯片 -xWRL6432

本文编辑:调皮哥的小助理 概述 最近,德州仪器(TI)推出了单芯片低功耗 57GHz 至 64GHz 工业(汽车)毫米波雷达传感器IWRL6432,具有 7GHz 的连续带宽,可实现更高分辨率。除了UWB雷达之外,IWRL6432目前是毫米波雷达带宽最…

漏洞挖掘-不安全的HTTP方法

前言: 年关将至,这可能是年前最后一篇文章了。已经有一段时间没有更新文章了,因为最近也没有学到什么新的知识,也就没什么可写的,又不想灌水。最近关注的好兄弟们多了很多,在这里也是十分感谢大家的支持&am…

Make RepVGG Greater Again | 中文翻译

性能和推理速度之间的权衡对于实际应用至关重要。而重参化可以让模型获得了更好的性能平衡,这也促使它正在成为现代卷积神经网络中越来越流行的架构。尽管如此,当需要INT8推断时,其量化性能通常太差,无法部署(例如Imag…

SQL BETWEEN 操作符

BETWEEN 操作符用于选取介于两个值之间的数据范围内的值。 SQL BETWEEN 操作符 BETWEEN 操作符选取介于两个值之间的数据范围内的值。这些值可以是数值、文本或者日期。 SQL BETWEEN 语法 SELECT column1, column2, ... FROM table_name WHERE column BETWEEN value1 AND va…

力扣714题 买卖股票的最佳时机含手续费

class Solution {public int maxProfit(int[] prices, int fee) {// 买第一天股票所需要的全部费用(买入)int buy prices[0] fee; // 利润总和int sum 0;for(int p:prices){// 如果买后些天的股票所需的全部费用比第一天的少,就买后边这天的(买入)if(p fee < buy){buy …

【Python】python深拷贝和浅拷贝(一)

【Python】python深拷贝和浅拷贝&#xff08;一&#xff09; 定义 直接赋值&#xff1a;其实就是对象的引用。浅拷贝&#xff1a;拷贝父对象&#xff0c;不会拷贝对象的内部的子对象。深拷贝&#xff1a; copy 模块的 deepcopy 方法&#xff0c;完全拷贝了父对象及其子对象。…

SpringBoot过滤器与拦截器

为什么要有过滤器和拦截器&#xff1f; 在实际开发过程中&#xff0c;经常会碰见一些比如系统启动初始化信息、统计在线人数、在线用户数、过滤敏高词汇、访问权限控制(URL级别)等业务需求。这些对于业务来说一般上是无关的&#xff0c;业务方是无需关注的&#xff0c;业务只需…

Ubuntu20.04安装ROS Noetic

一、实验环境准备 1.使用系统:Ubuntu20.04&#xff08;安装不做赘述&#xff0c;可看我另外一篇博客Ubuntu20.04安装&#xff09;&#xff0c;可到Ubuntu官网下载https://ubuntu.com/download/desktop 2.配置网络&#xff0c;使其可通互联网 二、在Ubuntu20.04上搭建ROS机器人…

树上差分-LCA

树上差分算法分析&#xff1a;练习例题差分的基本思想详情见博客&#xff08;一维、二维差分&#xff09;&#xff1a; https://blog.csdn.net/weixin_45629285/article/details/111146240 算法分析&#xff1a; 面向的对象可以是树上的结点&#xff0c;也可以是树上的边 结点…

springmvc 文件上传请求转换为MultipartFile的过程

前言: 最近在研究文件上传的问题,所以就写下这个博客,让大家都知道从流转换为MutipartFile的过程,不然你就知道在方法中使用,而不知道是怎么样处理的,是不行的 从DiaspatherServlet说起: 别问为啥,去了解tomcat和servlet的关系,我后面会 写这篇博客的 servlet的生命周期 ini…

[ 数据结构 ] 查找算法--------线性、二分、插值、斐波那契查找

0 前言 查找算法有4种: 线性查找 二分查找/折半查找 插值查找 斐波那契查找 1 线性查找 思路:线性遍历数组元素,与目标值比较,相同则返回下标 /**** param arr 给定数组* param value 目标元素值* return 返回目标元素的下标,没找到返回-1*/public static int search(…