记录一次使用卷积神经网络进行图片二分类的实战

news2024/9/22 0:54:10

写在前面

笔者目前就读的专业是软件工程,并非人工智能专业,但是由于对人工智能有兴趣,于是课下进行了一些自学。正巧最近有些闲暇时间,就想着使用自学的内容做个小型的实战。这篇文章的主要目的也就是从一个入门者的角度,去记录一下这整个流程,顺便也分享一下自己的心得体会。(这里就假定读者都是有一定的深度学习基础了,一些简单的概念,例如k折交叉验证,就不再具体阐述了)
参考资料采用的是沐神的《动手深度学习》,谷歌一下就能找到这本书的pdf版本

需求分析

在b站上有很多视频,每个视频都会有一个封面,笔者希望能训练出一个模型,这个模型可以帮我们在一大堆视频中找到关于猫的视频。具体而言,就是将封面图片作为输入,将封面中含有猫的确信度作为输出,如果确信度大于50%就认为这个封面里面有猫,从而判断这个视频是和猫相关的。这就相当于一个另类的内容推荐系统了,从一大堆视频中找到自己感兴趣的视频,就是这个模型的实用价值。

准备数据集

数据集是深度学习实战中非常重要而且困难的一个环节,通常在教学中,我们都是使用现成的数据集,来进行模型的训练,例如在CNN教学中最经典的数据集Fashion-MNIST。真正开始自己DIY数据集的时候,摆在我们面前的难题就有2个:

  1. 如何选择负例数据?
    这个问题确实困扰了作为入门者的笔者一段时间,因为正例非常好找,只需要找一大堆包含猫的图片即可,但是负例呢?只需要不包含猫的图片都行吗?如果是的话,那一大堆纯色图也可以作为负例吗?
    在使用教学用的数据集,例如Fashion-MNIST时,我们并没有思考这个问题,因为Fashion-MNIST中的数据只有10个特征都很鲜明的类别,例如,如果我们想要判断一张图是不是衬衫,表面上我们做的事是训练模型来得出 这张图是衬衫 和 这张图不是衬衫 的可信度,但实际上我们做的事是,训练模型来得出 这张图是衬衫 和 这张图是大衣,凉鞋,…等另外9个类别物品 的可信度,但如果按照这样的思路,那我们在找负例的时候,是不是需要穷举所有图片中不包含猫的情况?例如一些包含狗的图片,一些包含汽车的图片…,如果是这样的话,那就不具备可操作性了,因为工作量太大了。
    笔者翻了大量资料,也没有这方面的回答,如果读者对这个问题有所了解,还恳请在评论区赐教。
    最终笔者采取了一个比较妥协的方法,就是在b站里面随机找视频,然后保存封面作为负例,在获取了一定量的封面后,笔者再进行人工筛选,剔除掉包含猫的图片。

  2. 如何获取到大量的图片?
    这一点就比较简单了,只需要使用python爬虫,来爬取b站视频封面即可,具体代码由于可能涉及到版权问题,就不放出来了。
    需要注意的一点是,这一步还需要在下载图片后对图片进行缩放,因为之后输入模型的图片的大小都是统一的大小,这里笔者设定的大小是240x240。
    处理图片可以使用opencv-python库,安装方法为pip install opencv-python,图片的下载和缩放可以用下面的代码解决(可以稍微注意一下,opencv中读入的图片是以numpy中的ndarray的形式保存的)

def save_img(url,path):
    #Download image
    res = requests.get(url)
    img = res.content
    with open(path, "wb") as f:
        f.write(img)
    print("Downloaded " + path)

    #Resize image
    img = cv2.imread(path)
    img = cv2.resize(img, dsize=(240, 240), fx=1, fy=1, interpolation=cv2.INTER_LINEAR)
    cv2.imwrite(path, img, [cv2.IMWRITE_JPEG_QUALITY, 100])
    print("Processed " + path)

搭建模型

接下来就正式的写深度学习代码了,笔者使用的框架是pytorch
最开始搭建的模型的网络结构参考了著名的LeNet,也就是最开始使用卷积神经网络进行手写数字识别的网络,具体网络结构如下
在这里插入图片描述
激活函数均使用的ReLU,代码表现为

def get_net():
    return nn.Sequential(
        nn.Conv2d(3, 6, kernel_size=45), nn.ReLU(),
        nn.AvgPool2d(kernel_size=18, stride=2),
        nn.Conv2d(6, 9, kernel_size=20), nn.ReLU(),
        nn.AvgPool2d(kernel_size=9, stride=2),
        nn.Flatten(),
        nn.Linear(9216, 544), nn.ReLU(),
        nn.Linear(544, 68),nn.ReLU(),
        nn.Linear(68, 2), nn.Softmax()
    ).to(device)

整个项目的完整代码如下

import random

import torch
from torch import nn

from matplotlib import pyplot as plt
import numpy as np
import cv2
import os

from my_dataset import get_dataloader

device="cuda"
img_size=240


def read_img_to_numpy(img_path):
    img = cv2.imread(img_path)/128  #为了进行归一化
    img = np.concatenate(
        (img[:, :, 0].reshape((1, img_size, img_size)), img[:, :, 1].reshape((1, img_size, img_size)), img[:, :, 2].reshape((1, img_size, img_size))),
        axis=0)
    return img


def read_all_img():
    positive_dir="samples/positive/"
    negative_dir="samples/negative/"

    features=None
    labels=[]

    names=[positive_dir+name for name in os.listdir(positive_dir)]+[negative_dir+name for name in os.listdir(negative_dir)]
    indexes=list(range(len(names)))
    random.shuffle(indexes)

    for index in indexes:
        name=names[index]
        if positive_dir in name:
            label=1
        else:
            label=0

        labels.append(label)
        img = read_img_to_numpy(name).reshape(1, 3, img_size, img_size)
        img = torch.tensor(img,dtype=torch.float32, device=device)
        if features is None:
            features=img
        else:
            features=torch.concat((features,img))

    labels=torch.tensor(labels,dtype=torch.int64,device=device)

    return features,labels


def get_net():
    return nn.Sequential(
        nn.Conv2d(3, 6, kernel_size=45), nn.ReLU(),
        nn.AvgPool2d(kernel_size=18, stride=2),
        nn.Conv2d(6, 9, kernel_size=20), nn.ReLU(),
        nn.AvgPool2d(kernel_size=9, stride=2),
        nn.Flatten(),
        nn.Linear(9216, 544), nn.ReLU(),
        nn.Linear(544, 68),nn.ReLU(),
        nn.Linear(68, 2), nn.Softmax()
    ).to(device)


def eval_accuracy(net,test_iter):
    total=0
    accurate=0
    for X,y in test_iter:
        y_hat=net(X).argmax(axis=1)

        e=(y==y_hat)
        accurate+=e.sum()
        total+=len(X)

    return accurate/total


def train_for_k_fold(net,train_iter,test_iter,lr,epochs,fold):
    def init_weights(m):
        if type(m)==nn.Linear or type(m)==nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    optimizer=torch.optim.Adam(net.parameters(),lr=lr)
    loss=nn.CrossEntropyLoss()
    loss_record=[]
    accuracy_record=[]
    for epoch in range(epochs):
        net.train()
        for X,y in train_iter:
            optimizer.zero_grad()
            y_hat=net(X)
            l=loss(y_hat,y)
            l.backward()
            optimizer.step()
        loss_record.append(l.to("cpu").detach().numpy())
        accuracy=eval_accuracy(net,test_iter)
        accuracy_record.append(accuracy.to("cpu").detach().numpy())

        print(f"Epoch {epoch}, loss {l}, accuracy {accuracy}")
    #plot
    print(f"Loss {loss_record}")
    print(f"Accuracy {accuracy_record}")
    epoch=np.arange(len(loss_record))
    plt.title(f"fold {fold} lr {lr}, epochs {epochs}")
    plt.plot(epoch,np.array(loss_record),label="Loss")
    plt.plot(epoch,np.array(accuracy_record),label="Accuracy")
    plt.legend()
    plt.savefig(f"fold_{fold}_lr_{lr}_epochs_{epochs}.png")


def k_fold(k=4,lr=0.9,epochs=10,batch_size=3):
    features, labels = read_all_img()
    total_len = len(features)
    fold_len = int(total_len / k)

    for fold in range(k):
        print(f"Start fold {fold}")
        test_start = fold_len * fold
        test_end = min((fold + 1) * fold_len, total_len)

        test_features = features[test_start:test_end]
        test_labels = labels[test_start:test_end]

        train_features = torch.concat((features[:test_start], features[test_end:]))
        train_labels = torch.concat((labels[:test_start], labels[test_end:]))

        train_iter=get_dataloader(train_features,train_labels,batch_size)
        test_iter=get_dataloader(test_features,test_labels,batch_size)

        net = get_net()
        train_for_k_fold(net, train_iter, test_iter, lr, epochs,fold)

        #save net
        torch.save(net.state_dict(),f"fold_{fold}.params")


k_fold(lr=1,epochs=100,batch_size=10)

其中,my_dataset.py这个工具文件的代码如下

import torch
from torch.utils import data


class ArrayDataset(data.Dataset):
    def __init__(self,features,labels):
        self.features=features
        self.labels=labels

    def __getitem__(self, item):
        return self.features[item],self.labels[item]


    def __len__(self):
        return len(self.features)


def get_dataloader(features,labels,batch_size,device="cuda"):
    if not torch.is_tensor(features):
        features=torch.tensor(features,dtype=torch.float32,device=device)
        labels=torch.tensor(labels,dtype=torch.int64,device=device)

    return data.DataLoader(ArrayDataset(features,labels),batch_size,shuffle=True)

上面的代码实现的功能主要是做k折交叉验证来帮助我们寻找合适的超参数,正式的训练代码只需要在这基础上简单的改动即可,这里就不放出来了

训练模型

在训练模型阶段,笔者遇到了各种各样的问题,这也是整个过程中最折腾的部分

问题1:显卡显存不足

之前我们使用的模型和数据集都非常简单,所以2G的GTX1050完全能够胜任,但如今数据量是之前的几十倍,一运行起来,数据还没有全部加载完成,就报了显存不足,如下图所示
在这里插入图片描述
这个时候就只能去网上租借GPU服务器了,笔者找到个平台,注册就送10元代金券,目前总共花费了10.3元,其中10元还是平台提供的代金券抵扣的,而且这个平台的计费方式是按使用时长计费,也就是开机时计费,关机后数据会保留,但是不会计费,这对于我们这种只需要短时间使用GPU的学生是非常划算的。为了避免有打广告的嫌疑,平台具体的名字就不放出来了,如果大家有兴趣可以私信笔者
至此,显存不足的问题算是解决了,写到这里,有了一点感悟,这也是沐神在书中提到的一点,早期人工智能领域的发展速度没有当今快,有一个很大的原因就是硬件资源不足。在2000年左右,可能一块2GB显存的显卡都是很贵的硬件,但是后面跑AlexNet至少也需要20GB的显存,所以可见在早期时候,像AlexNet之类的网络,即使能够实现,也很难展开运算

问题2:模型对所有样本的输出相同

相信这是初学者都会碰到的情况,就是无论输入是什么,模型最终的输出都是一样的
具体回到我们这个实战,在训练过程中,我发现每一个epoch的准确率都是相等的,这显然很异常,所以我修改了一下代码,在eval_accuracy函数处做了如下修改

        # y_hat=net(X).argmax(axis=1)
        y_hat=net(X)
        print(y_hat)
        y_hat=y_hat.argmax(axis=1)

这段代码的目的就是为了打印出y_hat,也就是输入为正例和负例的确信度,也就是softmax层的输出
结果输出很奇怪,所有的输出都是一样的,就像下面这张图中的结果
有了上次处理Dead ReLU的经验,这次可以基本确定是学习率太大了,所以尝试把学习率调到1e-3,但是还是一样的结果,最后把学习率调到1e-8,终于不再是清一色的0和1了,但是收敛速度还是略慢,所以继续调高学习率,最终确定为1e-5是个比较合适的值,可以看到,调为1e-5后,输出就正常了
在这里插入图片描述
从这里可以总结出的一个经验就是,遇到输出全部一样的情况,尽管把学习率往小调,直到输出正常,然后再逐步往大调,来提高学习速率,同时需要先解决欠拟合,再考虑解决过拟合

训练结果

最终k折交叉验证的结果如下图(就只取第一个fold的结果展示了,另外3个fold也是类似的)
在这里插入图片描述
最终测试下来,准确率在80%左右,这已经基本达到我的预期了,因为至少这是一个能用的模型了

模型改进

后面笔者也尝试通过更改优化器,学习率,学习周期等方法来增加准确率,但是都没能成功
推测一方面是数据集不够多,例如一张图片里的内容是一个楼梯,但是反例中并没有出现楼梯,所以就会造成反例的确信度降低
另一方面,这个可能是这个网络结构的上限就是如此了,所以我们或许可以通过更改网络结构来提高准确率
其中扩增数据集的工作量较大,而且效果不显著,所以就不再考虑了,这里主要尝试通过修改网络结构来提高准确率
新的网络结构参考了沐神书中提到的AlexNet
在这里插入图片描述
具体代码实现为,将get_net函数修改为如下

def get_net():
    return nn.Sequential(
        nn.Conv2d(3, 96, kernel_size=11,stride=4,padding=1), nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2),
        nn.Conv2d(96, 256, kernel_size=5,padding=2), nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2),
        nn.Conv2d(256,384,kernel_size=3,padding=1),nn.ReLU(),
        nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),
        nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),
        nn.MaxPool2d(kernel_size=3,stride=2),
        nn.Flatten(),
        nn.Linear(9216, 4096), nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(4096, 4096), nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(4096, 2), nn.Softmax()
    ).to(device)

可以看到,这个网络结构更为复杂了,所以需要的显存也更大了,在实际测试中,10GB的RTX3080已经顶不住了,需要换上24GB的RTX3090来进行训练,最终训练出的模型在测试数据集中的表现达到了87%
其实这并不奇怪,更多的网络参数就意味着这个模型的拟合能力更强,但同时过拟合的可能性也越大,这也是需要我们去权衡的一个trade off

写在最后

这是笔者第一次进行人工智能方向的实战,能做出成果,自然是非常开心的。在这过程中,最深的感悟有两点:

  1. 数据为王
    虽然深度学习不需要像早起机器学习一样,去精心处理数据,但是大量且合理的数据仍然是有必要的,还记得在kaggle上做过一个数字识别的题目,当时kaggle给出的训练数据集中包含了40000个,即使是使用比较简单的LeNet,也能做到91%的识别准确率,所以可见,如果增大数据量,这个模型是有进步空间的
  2. 硬件决定上限
    前面也提到过了,在人工智能发展的早期,硬件资源制约了这个领域的发展。在那个显卡显存可能只有512MB的时代,如果要跑AlexNet这种光是模型参数都有221MB的模型,应该是非常困难的

最后,这篇文章中可能有一些错误或者描述不清的地方,欢迎大家在评论区里批评指正,也欢迎大家在评论区进行探讨与交流

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/99172.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C++】list

本期就来讲讲list的使用技巧 文章目录list的介绍及使用list的介绍list迭代器失效list的模拟实现list与vector的对比我们前面知道迭代器是一个像指针一样的东西,但是在C里面,出来string和vector,其他类都不能 将迭代器当成指针使用&#xff0c…

二叉树的非递归与相关oj

🧸🧸🧸各位大佬大家好,我是猪皮兄弟🧸🧸🧸 文章目录一、二叉树相关oj①二叉搜索树与双向链表②前序遍历和中序遍历构造二叉树二、二叉树的非递归①前序遍历非递归②中序遍历非递归③后序遍历非…

简单的算法思想 - 利用快慢指针解决问题 - 寻找链表中的中间节点,回文序列,倒数第k个节点 - 详解

文章目录1. 寻找链表中倒数第K个节点1.1. 思路分析1.2 代码实现2. 寻找链表中的中间结点2.1 思路概述2.2 代码实现3. 链表的回文结构3.1 思路分析3.2 代码实现总结✨✨✨学习的道路很枯燥,希望我们能并肩走下来! 本文通过寻找链表中的中间节点&#xff0…

汽车托运网址

开发工具(eclipse/idea/vscode等): 数据库(sqlite/mysql/sqlserver等): 功能模块(请用文字描述,至少200字): 基于Web的汽车托运网站的设计与实现 网站前台:关于我们、联系我们、公告信息、卡车类型、卡车信息、运输评论…

【语音处理】一种增强的隐写及其在IP语音隐写中的应用(Matlab代码实现)

👨‍🎓个人主页:研学社的博客 💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜…

Effective Objective-C 2.0学习记录(一)

48.多用块枚举,少用for循环for循环快速枚举(快速遍历)基于块的遍历方式在编程中经常需要用到列举collection(NSArray、NSDictionary、NSSet等)中的元素,当前的Objective-C语言有多种办法实现此功能&#xf…

【专栏】核心篇09| 怎么保证缓存与DB的数据一致性

计算机类PDF整理:【详细!!】计算机类PDF整理 Redis专栏合集 【专栏】01| Redis夜的第一章 【专栏】基础篇02| Redis 旁路缓存的价值 【专栏】基础篇03| Redis 花样的数据结构 【专栏】基础篇04| Redis 该怎么保证数据不丢失(上…

Python -- 模块和包

目录 1.Python中的模块 1.1 import 1.3 from...import * 1.4 as别名 2.常见的系统模块和使用 2.1 OS模块 2.2 sys模块 2.3 math模块 2.4 random模块 2.5 datetime模块 2.6 time模块 2.7 calendar模块 2.8 hashlib模块 2.9 hmac模块 2.10 copy模块 3.pip命令的使…

【机器学习---01】机器学习

文章目录1. 什么是机器学习?2. 机器学习分类2.1 基本分类2.2 按模型分类2.3 其他分类(不重要)3. 机器学习三要素4. 监督学习的应用(分类、标注、回归问题)1. 什么是机器学习? 定义:给定训练集D,让计算机从一个函数集合F {f1(x)&…

虚拟机打不开,提示“此主机不支持虚拟化实际模式”的详细解决方法

虚拟机打不开,提示“此主机不支持虚拟机实际模式”的解决方法 一、第一种情况安装/启动虚拟机失败, 在VMWare软件中,安装/启动虚拟机时,如果出以类似以下的错误提示: 出现该提示是由于电脑不支持虚拟化技术或是相关功…

IDEA报错:类文件具有错误的版本 61.0,应为52.0

springboot项目启动报错: 类文件具有错误的版本 61.0,应为52.0 请删除该文件或确保该文件位于正确的类路径子目录中 查阅了网上的很多资料,普遍原因说是springboot版本过高,高于3.0 需要在pom文件中降低版本 也有说是idea的maven配置java版…

网购商城网站

开发工具(eclipse/idea/vscode等): 数据库(sqlite/mysql/sqlserver等): 功能模块(请用文字描述,至少200字):

【Python机器学习】层次聚类AGNES、二分K-Means算法的讲解及实战演示(图文解释 附源码)

需要源码和数据集请点赞关注收藏后评论区留言私信~~~ 层次聚类 在聚类算法中,有一类研究执行过程的算法,它们以其他聚类算法为基础,通过不同的运用方式试图达到提高效率,避免局部最优等目的,这类算法主要有网格聚类和…

easypoi导入excel空指针异常

问题描述 前端页面停留在导入页面,通过后端返回的接口,确认后端已经抛出异常查看系统调用错误日志为 java.lang.NullPointerException: nullat org.apache.poi.xssf.usermodel.XSSFClientAnchor.setCol2(XSSFClientAnchor.java:231)at org.apache.poi.…

基于EKF的四旋翼无人机姿态估计matlab仿真

目录 1.算法描述 2.仿真效果预览 3.MATLAB核心程序 4.完整MATLAB 1.算法描述 卡尔曼滤波是一种高效率的递归滤波器(自回归滤波器),它能够从一系列的不完全包含噪声的测量中,估计动态系统的状态。这种滤波方法以它的发明者鲁道夫E卡尔曼(R…

Android koin

1.源码地址 1.源码地址 2.作用 1.让代码看起来更简洁 现在是这样创建对象的 2.解耦 我们有一个类,然后有100个地方使用它,这个时候如果我们要修改构造参数,加入一个参数,那么我们就要修改100个地方;如果过了一个…

怎样让chatGPT给你打工然后月入过千?

前言 chatGPT最近火出圈了,怎么薅一个文字模型给你打工呢? 这个UP给了个思路:哔哩哔哩 emmm有点尴尬,可能是热度比较高,b站的视频作者自己下架了。 总结一下: 薅的对象百度文库创作中心:地址…

设计模式之装饰器模式

decorator design pattern 装饰模式的概念、装饰模式的结构、装饰模式的优缺点、装饰模式的使用场景、装饰模式与代理模式的区别、装饰模式的实现示例、装饰模式的源码分析 1、装饰模式的概念 装饰模式,即在不改变现有对象结构的前提下,动态的给对象增加…

【云原生】Grafana 介绍与实战操作

文章目录一、概述二、Grafana 安装1)下载安装2)安装包信息3)启动服务4)Grafana 访问三、Grafana 功能介绍四、使用mysql存储1)安装mysql2)修改grafana配置1、创建grafana用户和grafana库2、修改grafana配置…

[附源码]Python计算机毕业设计Django学分制环境下本科生学业预警帮扶系统

项目运行 环境配置: Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术: django python Vue 等等组成,B/S模式 pychram管理等等。 环境需要 1.运行环境:最好是python3.7.7,…